diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5b8012db37d57df3acf45cafd5e27ea6ed704da Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/code_template.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/code_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..919de9bc9724a34e3267464c97a224bd065a0ab5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/code_template.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/context.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/context.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7d82b6e433d4e2adfdffb6c9ae42494fcc276dd Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/context.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/gen_aoti_c_shim.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/gen_aoti_c_shim.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ba44b86fc15430910107e8776b7037595c19105 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/gen_aoti_c_shim.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/gen_backend_stubs.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/gen_backend_stubs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a58be48a6aff3ab69a7f57f5af6016546e786a2b Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/gen_backend_stubs.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/gen_functionalization_type.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/gen_functionalization_type.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af6a91270024fd08cc21763c6db4f871667cf04d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/gen_functionalization_type.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/gen_lazy_tensor.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/gen_lazy_tensor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb531b8630189e8585a7cc55d010625cdc6a9114 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/gen_lazy_tensor.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/gen_schema_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/gen_schema_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..801dabedec5b19bb90698cd77ac64b1e02f0c91c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/gen_schema_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/gen_vmap_plumbing.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/gen_vmap_plumbing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..285ad001c7470d3196670b0c9aff33ac087b723c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/gen_vmap_plumbing.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/local.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/local.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa2daf040b770f5d3aba6afc57cf6ca71301cd26 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/local.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/native_function_generation.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/native_function_generation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87dea78b2edd86ced2781c6cf16e06df604aea2f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/native_function_generation.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f81d35e97f89fbb34326a7c0158f4d68a60718ff Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/yaml_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/yaml_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..324933689cd08a755260f30d554f4c1967942f35 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/__pycache__/yaml_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/aoti/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/aoti/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/aoti/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/aoti/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cac003756c10d2ef73c67c4af598ee7d5777a94e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/aoti/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/aoti/__pycache__/fallback_ops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/aoti/__pycache__/fallback_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcf982c7e75d310b7dee7292501d3f2056dc7ab7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/aoti/__pycache__/fallback_ops.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/aoti/fallback_ops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/aoti/fallback_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..f78cc85e22676edfa5ec90e5c6f204f5bfaea10a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/aoti/fallback_ops.py @@ -0,0 +1,194 @@ +# Be extra careful when you edit this file, because it affects AOTInductor ABI compatibility. See +# https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 +# for details. +# +# The inductor_fallback_ops list is based on the fallback ops from torch/_inductor/lowering.py. +# +# Generally speaking, it is ok to add a new op to the list, but you need to run +# `python torchgen/gen.py --update-aoti-c-shim` in order to regenerate C shim header files. +# But it is NOT ok to remove an existing fallback op from the list, since that will break +# some existing AOTInductor-compiled models. +# +# A fallback op version defaults to 1. If you want to extend an existing fallback op by adding +# a new argument with a default value, while it is fine in the Python world, it will be BC-breaking +# when generating C shim. Thus you need to bump up the version number of that fallback op by +# updating the entry in the inductor_fallback_ops list, adding a new version number with a list +# of new arguments, and then run `python torchgen/gen.py --update-aoti-c-shim` to regenerate. + +inductor_fallback_ops: dict[str, dict[str, list[str]]] = { + "aten._adaptive_avg_pool2d_backward.default": {}, + "aten._adaptive_avg_pool2d.default": {}, + "aten._adaptive_avg_pool3d_backward.default": {}, + "aten._adaptive_avg_pool3d.default": {}, + "aten._addmm_activation.default": {}, + "aten._cdist_backward.default": {}, + "aten._cdist_forward.default": {}, + "aten._cudnn_rnn.default": {}, + "aten._dyn_quant_matmul_4bit.default": {}, + "aten._dyn_quant_pack_4bit_weight.default": {}, + "aten._efficient_attention_backward.default": {}, + "aten._efficient_attention_forward.default": {}, + "aten._efficientzerotensor.default": {}, + "aten._embedding_bag_dense_backward.default": {}, + "aten._embedding_bag_forward_only.default": {}, + "aten._embedding_bag_per_sample_weights_backward.default": {}, + "aten._embedding_bag.default": {}, + "aten._fft_c2c.default": {}, + "aten._fft_r2c.default": {}, + "aten._flash_attention_backward.default": {}, + "aten._flash_attention_forward.default": {}, + "aten._fused_moving_avg_obs_fq_helper_functional.default": {}, + "aten._fused_moving_avg_obs_fq_helper.default": {}, + "aten._fused_rms_norm.default": {}, + "aten._histogramdd_from_bin_cts.default": {}, + "aten._int_mm.out": {}, + "aten._pdist_backward.default": {}, + "aten._pdist_forward.default": {}, + "aten._scaled_dot_product_attention_math_for_mps.default": {}, + "aten._scaled_dot_product_cudnn_attention_backward.default": {}, + "aten._scaled_dot_product_cudnn_attention.default": {}, + "aten._scaled_dot_product_efficient_attention_backward.default": {}, + "aten._scaled_dot_product_efficient_attention.default": {}, + "aten._scaled_dot_product_flash_attention_backward.default": {}, + "aten._scaled_dot_product_flash_attention_for_cpu_backward.default": {}, + "aten._scaled_dot_product_flash_attention_for_cpu.default": {}, + "aten._scaled_dot_product_flash_attention.default": {}, + "aten._scaled_dot_product_fused_attention_overrideable_backward.default": {}, + "aten._scaled_dot_product_fused_attention_overrideable.default": {}, + "aten._scaled_mm.default": {}, + "aten._scaled_grouped_mm.default": {}, + "aten._scaled_mm.out": {}, + "aten._segment_reduce_backward.default": {}, + "aten._thnn_fused_lstm_cell.default": {}, + "aten._to_sparse.default": {}, + "aten._trilinear.default": {}, + "aten._weight_int4pack_mm.default": {}, + "aten._weight_int8pack_mm.default": {}, + "aten.abs.default": {}, + "aten.adaptive_max_pool2d_backward.default": {}, + "aten.adaptive_max_pool2d.default": {}, + "aten.adaptive_max_pool3d_backward.default": {}, + "aten.adaptive_max_pool3d.default": {}, + "aten.add.Scalar": {}, + "aten.add.Tensor": {}, + "aten.addbmm.default": {}, + "aten.addmm.out": {}, + "aten.addmv.default": {}, + "aten.angle.default": {}, + "aten.avg_pool2d_backward.default": {}, + "aten.avg_pool2d.default": {}, + "aten.avg_pool3d_backward.default": {}, + "aten.avg_pool3d.default": {}, + "aten.baddbmm.out": {}, + "aten.bernoulli_.float": {}, + "aten.bernoulli_.Tensor": {}, + "aten.bmm.out": {}, + "aten.bucketize.Tensor": {}, + "aten.cat.default": {}, + "aten.cholesky_inverse.default": {}, + "aten.cholesky_solve.default": {}, + "aten.convolution_backward.default": {}, + "aten.convolution.default": {}, + "aten.cummax.default": {}, + "aten.cummin.default": {}, + "aten.cumprod.default": {}, + "aten.cumsum.default": {}, + "aten.exponential.default": {}, + "aten.fill_.Scalar": {}, + "aten.fractional_max_pool2d_backward.default": {}, + "aten.fractional_max_pool2d.default": {}, + "aten.fractional_max_pool3d_backward.default": {}, + "aten.fractional_max_pool3d.default": {}, + "aten.gcd.default": {}, + "aten.geqrf.default": {}, + "aten.grid_sampler_2d_backward.default": {}, + "aten.hann_window.default": {}, + "aten.histc.default": {}, + "aten.histogram.bin_ct": {}, + "aten.index_put.default": {}, + "aten.index_reduce.default": {}, + "aten.index.Tensor": {}, + "aten.kthvalue.default": {}, + "aten.logcumsumexp.default": {}, + "aten.lu_unpack.default": {}, + "aten.masked_scatter_backward.default": {}, + "aten.masked_scatter.default": {}, + "aten.masked_select.default": {}, + "aten.max_pool2d_with_indices_backward.default": {}, + "aten.max_pool2d_with_indices.default": {}, + "aten.max_pool3d_with_indices_backward.default": {}, + "aten.max_pool3d_with_indices.default": {}, + "aten.max_unpool2d.default": {}, + "aten.max_unpool3d.default": {}, + "aten.median.default": {}, + "aten.mm.out": {}, + "aten.mode.default": {}, + "aten.mul.Scalar": {}, + "aten.mul.Tensor": {}, + "aten.nanmedian.default": {}, + "aten.narrow.default": {}, + "aten.native_dropout.default": {}, + "aten.nonzero.default": {}, + "aten.normal_functional.default": {}, + "aten.ormqr.default": {}, + "aten.pad.default": {}, + "aten.permute.default": {}, + "aten.polar.default": {}, + "aten.pow.Scalar": {}, + "aten.pow.Tensor_Scalar": {}, + "aten.pow.Tensor_Tensor": {}, + "aten.rand.default": {}, + "aten.rand.generator": {}, + "aten.randint.default": {}, + "aten.randint.generator": {}, + "aten.randint.low_out": {}, + "aten.randint.low": {}, + "aten.randn.default": {}, + "aten.randn.generator": {}, + "aten.randperm.default": {}, + "aten.repeat_interleave.Tensor": {}, + "aten.replication_pad1d_backward.default": {}, + "aten.replication_pad2d_backward.default": {}, + "aten.reshape.default": {}, + "aten.resize_.default": {}, + "aten.resize_as_.default": {}, + "aten.scatter_reduce.two_out": {}, + "aten.scatter.src_out": {}, + "aten.scatter.value_out": {}, + "aten.searchsorted.Scalar": {}, + "aten.searchsorted.Tensor": {}, + "aten.segment_reduce.default": {}, + "aten.set_.source_Tensor": {}, + "aten.slice.Tensor": {}, + "aten.soft_margin_loss_backward.default": {}, + "aten.sort.default": {}, + "aten.sort.stable": {}, + "aten.squeeze.dim": {}, + "aten.to_sparse.default": {}, + "aten.topk.default": {}, + "aten.triangular_solve.default": {}, + "aten.uniform.default": {}, + "aten.upsample_bicubic2d_backward.default": {}, + "aten.upsample_linear1d_backward.default": {}, + "aten.upsample_trilinear3d_backward.default": {}, + "aten.view_as_complex.default": {}, + "aten.view_as_real.default": {}, + "aten.view.dtype": {}, + "aten._weight_int4pack_mm_with_scales_and_zeros.default": {}, +} + +# `python torchgen/gen.py --update-aoti-c-shim` will automatically generate +# c_shim_aten.{h/cpp} based on the list below. +# Operators in this list are intended to be used in torch/csrc/stable/ops.h +# Unlike other c_shims, operators in this file do not bypass the dispatcher. +# The same BC rules apply as inductor_fallback_ops. +aten_shimified_ops: dict[str, dict[str, list[str]]] = { + "aten.fill_.Scalar": {}, + "aten.pad.default": {}, + "aten.narrow.default": {}, + "aten.amax.default": {}, + "aten.new_empty.default": {}, + "aten.new_zeros.default": {}, + "aten.full.default": {}, + "aten.subtract.Tensor": {}, +} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ef060b1f17c4db6158312b1af00df3e1f9da916 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/autograd.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/autograd.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7178b50d39f0e2e8460868a9429638c39069e643 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/autograd.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/cpp.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/cpp.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93e44213288ec15de3db5be3368276cfd6fd97df Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/cpp.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/dispatcher.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/dispatcher.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5ef585444027a555816d9d99d168de59c707d53 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/dispatcher.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/functionalization.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/functionalization.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f19d0d7104bf47b769b4064de3f8fec79791298f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/functionalization.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/lazy.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/lazy.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b137eb95463c377a5462fd548238862507a8837 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/lazy.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/meta.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/meta.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d27c02299de079eaa8650e8c710b02b40a0d7416 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/meta.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/native.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/native.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ca6640d225722eaea16533e6dc3dad9798c5454 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/native.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/python.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/python.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdbae62296bfd37ebdd12b32c3b059d8c6c011bc Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/python.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/structured.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/structured.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd2e0b15b34266f860279b0363167c7e918ec777 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/structured.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/translate.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/translate.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c063cfc20c6f7aa1e34b5829d5e8edaf682e86a6 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/translate.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/ufunc.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/ufunc.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd1886a10886667d8c9473f61e611d2844cfa326 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/ufunc.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/unboxing.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/unboxing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af852cc3057fe8c9ab7fdcffb12d614487a0dfb7 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/__pycache__/unboxing.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/autograd.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/autograd.py new file mode 100644 index 0000000000000000000000000000000000000000..96e192d3a48a9c72202e28117409ed99bc7377f5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/autograd.py @@ -0,0 +1,874 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import cast, TYPE_CHECKING + +from torchgen import local +from torchgen.api import cpp +from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT +from torchgen.model import ( + BaseTy, + BaseType, + FunctionSchema, + ListType, + NativeFunction, + NativeFunctionsViewGroup, + SchemaKind, + Type, +) +from torchgen.utils import IDENT_REGEX + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +# Represents a saved attribute involved in backward calculation. +# Note that it can be a derived property of an input argument, e.g.: +# we could save `other.scalar_type()` instead of the entire `other` tensor. +@dataclass(frozen=True) +class SavedAttribute: + # The NamedCType holds the updated name and cpp type of the attribute + # for the name, Suffix is appended if it's derived property, e.g.: `other_scalar_type` + nctype: NamedCType + + # The expression to read the derived property at save time, e.g.: + # `other.scalar_type()`. + expr: str + + +# Represents a backward formula that calculates derivatives for one +# or more tensors. +@dataclass(frozen=True) +class Derivative: + # The formula string (legit C++ expression). + # Note that expressions against input arguments have been replaced with the + # corresponding saved attributes. + # E.g.: + # raw formula: `mul_tensor_backward(grad, self, other.scalar_type())` + # here: `mul_tensor_backward(grad, self, other_scalar_type)` + formula: str + + # The formula string before input argument replacement + original_formula: str + + # Names of the arguments for which this formula calculates derivatives. + var_names: tuple[str, ...] + + # Saved inputs that are referenced by the formula. + saved_inputs: tuple[SavedAttribute, ...] + + # Saved outputs that are referenced by the formula. + saved_outputs: tuple[SavedAttribute, ...] + + # Gradients that are referenced by name in the formula. + named_gradients: set[str] + + +# Represents a forward formula that calculates forward derivatives +# for one tensor. +@dataclass(frozen=True) +class ForwardDerivative: + # The formula string (legit C++ expression). + # Note that special keywords such as "linear" or "element_wise" have been + # replaced by the automatically generated formula. + formula: str + + # Name of the output arguments for which this formula calculates forward + # derivatives + var_names: tuple[str, ...] + + # Type of the output arguments for which this formula calculates forward + # derivatives + var_types: tuple[Type, ...] + + # Inputs for which the forward derivatives are required for this formula + required_inputs_fw_grad: tuple[str, ...] | None + + # Inputs for which the primal is required for this formula + required_inputs_primal: tuple[str, ...] | None + + # Flag to specify if this formula requires the original value of self + # This is only used by inplace operations + required_original_self_value: bool + + # If this formula is specified in derivatives.yaml or if we are reusing the + # out of place formula for inplace + is_reusing_outplace_formula: bool + + +# Represents differentiability info for a NativeFunction. +@dataclass(frozen=True) +class DifferentiabilityInfo: + # The base name read from derivatives.yaml. + name: str + + # The matching native function. + # + # There can be multiple NativeFunction having the same base name: + # - different overloads with different types of input arguments; + # - in-place/out/functional variants of the same function; + # + # We first use the schema string (under the 'name' key) in derivatives.yaml + # to find the NativeFunction having the same schema string. + # Then we find the in-place/out/functional variants of the matching function. + # Among these variants, we choose the one having the same name as the + # derivatives.yaml entry. If there is no exact match, then we choose the + # in-place variant. + # TODO: maybe the logic to search for all variants is no longer necessary? + func: NativeFunction + + # The name of the generated autograd function. + # It's set only if we will calculate a derivative, i.e. + # 'args_with_derivatives' is not empty. + op: str | None + + # The derivatives formulae for this function. + # Note that the length of this sequence is the number of differentiable inputs + derivatives: Sequence[Derivative] + + # The forward derivatives formulae for this function. + # Note that the length of this sequence is the number of differentiable outputs + forward_derivatives: Sequence[ForwardDerivative] + + # The union of 'saved_inputs' of all 'derivatives'. + all_saved_inputs: Sequence[SavedAttribute] + + # The union of 'saved_outputs' of all 'derivatives'. + all_saved_outputs: Sequence[SavedAttribute] + + # All named gradients that are available for use, in the same + # order as in the grads vector. + available_named_gradients: Sequence[str] + + # The named gradients that are used in any of the derivatives. + # Invariant: all(name in available_named_gradients for name in used_named_gradients) + used_named_gradients: set[str] + + # The function's input arguments for which it calculates derivatives. + # It's the union of 'var_names' of all 'derivatives', sorted by the + # argument order in the function schema. + args_with_derivatives: Sequence[Binding] + + # Names of arguments whose derivative formula is 'non_differentiable'. + non_differentiable_arg_names: Sequence[str] + + # Raw data read from derivatives.yaml. + output_differentiability: list[bool] | None + + # output_differentiability in derivatives.yaml can be a list of + # conditions that express if the output is differentiable. In this case, + # the number of conditions must match the number of outputs + # (NB: we only support one condition right now). + # output_differentiability gets populated with True for each condition, + # while output_differentiability_conditions gets populated with the conditions + output_differentiability_conditions: list[str] | None + + @property + def has_derivatives(self) -> bool: + return len(self.args_with_derivatives) > 0 + + # Generates a new DifferentiabilityInfo using the exact same set of derivative information, + # but with a new operator name. + # This is used when generating "copy" variants of view ops, + # which are able to use the exact same derivative formula as the original view op + # See Note [Codegen'd {view}_copy Operators] + def create_view_copy_from_view_derivative( + self, g: NativeFunctionsViewGroup + ) -> DifferentiabilityInfo | None: + if g.view_copy is None: + return None + f = g.view_copy + + name_split_by_period = self.name.split(".", maxsplit=2) + # Append a "_copy" to the base name of the operator (but keep the overload name the same) + view_copy_name = f"{name_split_by_period[0]}_copy." + ".".join( + name_split_by_period[1:] + ) + view_copy_op_name = None if self.op is None else f"{self.op}_copy" + + return DifferentiabilityInfo( + # Use the "_copy" version of name/func/op + name=view_copy_name, + func=f, + op=view_copy_op_name, + # But keep all derivative info the same + derivatives=self.derivatives, + forward_derivatives=self.forward_derivatives, + all_saved_inputs=self.all_saved_inputs, + all_saved_outputs=self.all_saved_outputs, + available_named_gradients=self.available_named_gradients, + used_named_gradients=self.used_named_gradients, + args_with_derivatives=self.args_with_derivatives, + non_differentiable_arg_names=self.non_differentiable_arg_names, + output_differentiability=self.output_differentiability, + output_differentiability_conditions=self.output_differentiability_conditions, + ) + + +def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool: + if info is None: + return False + for derivative in info.derivatives: + formula = derivative.formula + if re.search(IDENT_REGEX.format(ident), formula): + return True + return False + + +def uses_retain_variables(info: DifferentiabilityInfo | None) -> bool: + return uses_ident(info, "retain_variables") + + +def uses_single_grad(info: DifferentiabilityInfo | None) -> bool: + return uses_ident(info, "grad") + + +# Represents a differentiable `Argument`. +# How is it different from the `Argument` type? +# - It's processed Arguments which are differentiable and only used in the +# context of the autograd codegen; +# - It can represent SelfArgument or regular Argument but not TensorOptionsArgument; +@dataclass(frozen=True) +class DifferentiableInput: + name: str + type: Type + + # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove. + cpp_type: str + + +# Represents a differentiable `Return`. +# How it it different from the `Return` type? +# - The name in `Return` is optional. Here it is always populated using the same +# `cpp.return_names()` method. +# TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant? +# - It's processed Returns which are differentiable, in compliance with the +# `output_differentiability` field defined in derivatives.yaml (if specified), +# and are only used in the context of the autograd codegen; +@dataclass(frozen=True) +class DifferentiableOutput: + name: str + type: Type + + # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove. + cpp_type: str + + +@dataclass(frozen=True) +class NativeFunctionWithDifferentiabilityInfo: + func: NativeFunction + info: dict[str, DifferentiabilityInfo] | None + fw_derivatives: dict[str, Sequence[ForwardDerivative]] | None + + +# TODO: Update comment below since it is out of date. +def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str: + """How are we going to call the underlying implementation of a + declaration? There are two strategies: + - use_derived: we want to call the implementation on CPUDoubleType + (or a similar, derived Type instance). Because these derived + instances deal in Tensors, not Variables (it's a completely different + object, so it doesn't dispatch back to VariableType), code on + this dispatch path needs to wrap/unwrap tensors. If the + derived implementation takes and returns tensors, the + implementation is usually differentiable (although we also use + the derived dispatch path for non-differentiable functions + that we still want to dispatch on the derived Type instance; + e.g., size()) + - use_type: we want to call the implementation on Type, because + it is implemented concretely, and the functions it invokes will + get dispatched back to VariableType (which will ensure that they + are differentiable.) + """ + # fn is derived as long as any of its per-key differentiability infos + # has_derivatives. dispatch_strategy() is used to guard generation of fns in VariableType + # and ADInplaceOrViewType. We want to generate these functions as long as a + # derivative is defined for ANY dispatch key. + if fn.func.is_abstract or ( + fn.info is not None and any(info.has_derivatives for info in fn.info.values()) + ): + # If the function is abstract (not implemented on at::Type), we must + # call the implementation on the derived type with unpacked tensors. + + # If the function has a derivative specified and is concrete, we could + # call either implementation. We prefer the calling the derived + # type's implementation with unpacked tensors because it is more + # performant in some cases: any internal calls to other ATen functions + # won't have the history tracked. + + # If the function has a type dispatched argument (i.e. is a factory), + # we prefer calling the derived type's implementation both because it is + # more performant and to ensure factory functions return tensors with _version + # of 0 (probably not strictly necessary, but nice to have to keeps versions simple + # to understand. + + return "use_derived" + else: + # If the function is concrete (we don't have to override it) and we + # didn't declare it in derivatives.yaml, we'll assume that it is + # actually implemented out of differentiable functions. (This + # assumption might not hold, but then you'll see gradcheck fail.) + return "use_type" + + +def is_foreach_func(f: NativeFunction) -> bool: + return f.func.name.name.base.startswith("_foreach_") + + +# note(crcrpar): Most foreach functions can reference an out-place `torch` function whose schema kind +# is functional for their backward derivatives (and forward derivatives in the future), i.e., +# they would find such one in `functional_info_by_signature`. There however are some exceptions: +_foreach_with_inplace_ref = {"_foreach_zero_"} +_foreach_with_tensor_overload = { + "_foreach_add.Tensor", + "_foreach_mul.Tensor", + "_foreach_div.Tensor", +} +# The following do not support the alpha kwarg, which the nonforeach versions support. +_skip_argument_len_check = { + "_foreach_add.Scalar", + "_foreach_add_.Scalar", + "_foreach_add.ScalarList", + "_foreach_add_.ScalarList", + "_foreach_sub.Scalar", + "_foreach_sub_.Scalar", + "_foreach_sub.ScalarList", + "_foreach_sub_.ScalarList", +} + + +# Checks if `function_schema` is a native, non-foreach function which `f`, a foreach function +# reference to generate derivatives. +def is_reference_for_foreach( + f: NativeFunction, + function_schema: FunctionSchema, +) -> bool: + return ( + f.func.name.name.base.split("_foreach_")[-1] == function_schema.name.name.base + and ( + not function_schema.name.name.inplace + or str(f.func.name) in _foreach_with_inplace_ref + ) + and ( + str(f.func.name) in _skip_argument_len_check + or len(f.func.arguments.flat_non_out) + == len(function_schema.arguments.flat_non_out) + ) + and all( + ref_arg.type in (arg.type, getattr(arg.type, "elem", None)) + for arg, ref_arg in zip( + f.func.arguments.flat_non_out, + function_schema.arguments.flat_non_out, + ) + ) + ) + + +# TODO(crcrpar): Avoid hard coding "Default" ideally. +def gen_foreach_derivativeinfo( + foreach_function: NativeFunction, + functional_info_by_signature: dict[ + FunctionSchema, dict[str, DifferentiabilityInfo] + ], + non_functional_info_by_signature: dict[ + FunctionSchema, dict[str, DifferentiabilityInfo] + ], + dispatch_key: str = "Default", +) -> tuple[DifferentiabilityInfo | None, bool]: + """Generate DifferentiabilityInfo for out-place foreach function, return the existing one for in-place. + + The second return value indicates whether the info is generated in this function. + """ + ref_diff_info: DifferentiabilityInfo | None = None + + for function_schema, diff_info in functional_info_by_signature.items(): + if not is_reference_for_foreach(foreach_function, function_schema): + continue + ref_diff_info = diff_info[dispatch_key] + if ref_diff_info is not None: + break + # note(crcrpar): It seems like `zero`'s info isn't available in functional_info_by_signature + # while the info of `zero_` is in non_functional_info_by_signature + if ( + ref_diff_info is None + and foreach_function.func.kind() == SchemaKind.inplace + and str(foreach_function.func.name) in _foreach_with_inplace_ref + ): + for function_schema, diff_info in non_functional_info_by_signature.items(): + if not is_reference_for_foreach(foreach_function, function_schema): + continue + ref_diff_info = diff_info[dispatch_key] + if ref_diff_info is not None: + break + if ref_diff_info is None: + return None, False + + # non out-place uses the existing Derivative. + if foreach_function.func.kind() == SchemaKind.inplace: + return ref_diff_info, False + + map_refarg2foreacharg, map_name2arg = {}, {} + for i, (arg, ref_arg) in enumerate( + zip( + foreach_function.func.arguments.flat_non_out, + function_schema.arguments.flat_non_out, + ) + ): + map_refarg2foreacharg[ref_arg.name] = arg.name + map_name2arg[arg.name] = arg + + all_saved_inputs, all_saved_outputs, all_var_names = [], [], [] + modified_derivative_formulas = [] + for i, derivative in enumerate(ref_diff_info.derivatives): + modified_formula = derivative.formula.replace("grad", "grads[i]").replace( + "result", "result[i]" + ) + saved_inputs, saved_outputs = [], [] + # note(crcrpar): This context seems necessary to call `cpp.argument_type` + with local.parametrize( + use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors, + use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group, + ): + for ref_input in derivative.saved_inputs: + ref_input_jit_name = ref_input.expr.split(".")[0] + mapped_name = map_refarg2foreacharg[ref_input_jit_name] + if isinstance(map_name2arg[mapped_name].type, ListType): + mapped_expr = mapped_name + "[i]" + else: + mapped_expr = mapped_name + new_expr = ref_input.expr.replace(ref_input_jit_name, mapped_expr) + modified_formula = modified_formula.replace( + cast(str, ref_input.nctype.name), new_expr + ) + + nctype = cpp.argument_type(map_name2arg[mapped_name], binds=mapped_name) + canonical_nctype = NamedCType( + nctype.name, nctype.type.remove_const_ref() + ) + saved_inputs.append( + SavedAttribute(nctype=canonical_nctype, expr=mapped_name) + ) + for ref_output in derivative.saved_outputs: + if ref_output.nctype.name == "result": + saved_outputs.append( + SavedAttribute( + nctype=NamedCType( + name="result", type=BaseCType(tensorListT) + ), + expr="result", + ) + ) + else: + raise RuntimeError("") + var_names = [map_refarg2foreacharg[var] for var in derivative.var_names] + all_var_names.extend(var_names) + all_saved_inputs.extend(saved_inputs) + all_saved_outputs.extend(saved_outputs) + modified_derivative = Derivative( + formula=modified_formula, + original_formula=derivative.formula, + var_names=tuple(var_names), + saved_inputs=tuple(saved_inputs), + saved_outputs=tuple(saved_outputs), + named_gradients=set(), + ) + modified_derivative_formulas.append(modified_derivative) + + with local.parametrize( + use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors, + use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group, + ): + args_with_derivatives = [ + Binding( + name=arg.name, + nctype=cpp.argument_type(arg, binds=arg.name), + argument=arg, + default=None, + ) + for arg in foreach_function.func.arguments.flat_non_out + if arg.name in all_var_names + ] + + forward_derivatives: list[ForwardDerivative] = [] + fw_derivative: ForwardDerivative + for fw_derivative in ref_diff_info.forward_derivatives: + var_names: list[str] = list(fw_derivative.var_names) # type: ignore[no-redef] + var_types: list[Type] = list(fw_derivative.var_types) + required_inputs_fw_grad: list[str] = [] + required_inputs_primal: list[str] = [] + if fw_derivative.required_inputs_fw_grad is not None: + required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad) + if fw_derivative.required_inputs_primal: + required_inputs_primal = list(fw_derivative.required_inputs_primal) + modified_formula = fw_derivative.formula + + # Foreach's result is TensorList + if "result" in modified_formula: + modified_formula = fw_derivative.formula.replace("result", "result[i]") + + for foreach_arg, ref_arg in zip( + foreach_function.func.arguments.flat_non_out, + ref_diff_info.func.func.arguments.flat_non_out, + ): + # Modify reference forward formula + if ( + isinstance(foreach_arg.type, ListType) + and not foreach_arg.type.is_tensor_like() + ): + # Assuming ScalarList + modified_formula = modified_formula.replace( + ref_arg.name, foreach_arg.name + "[i]" + ) + elif foreach_arg.type.is_tensor_like(): + # Assuming TensorList / Tensor + # assert isinstance(foreach_arg.type, ListType), f"{foreach_function.func.name}, {foreach_arg.type}" + assert isinstance(foreach_arg.type, ListType) or ( + foreach_arg.type == BaseType(BaseTy.Tensor) + and str(foreach_function.func.name) in _foreach_with_tensor_overload + ), f"{foreach_function.func.name}, {foreach_arg.type}" + for suffix in ("_p", "_t"): + curr_expr = ref_arg.name + suffix + if curr_expr in modified_formula: + new_expr = foreach_arg.name + suffix + modified_formula = modified_formula.replace(curr_expr, new_expr) + else: + # Assuming Scalar + if foreach_arg.name != ref_arg.name: + modified_formula = modified_formula.replace( + ref_arg.name, foreach_arg.name + ) + + # note(crcrpar): there should exist a cooler way... + for i, name in enumerate(var_names): + if name == ref_arg.name: + var_names[i] = foreach_arg.name + var_types[i] = foreach_arg.type + for i, name in enumerate(required_inputs_fw_grad): + if name == ref_arg.name: + required_inputs_fw_grad[i] = foreach_arg.name + for i, name in enumerate(required_inputs_primal): + if name == ref_arg.name: + required_inputs_primal[i] = foreach_arg.name + forward_derivatives.append( + ForwardDerivative( + formula=modified_formula, + var_names=tuple(var_names), + var_types=tuple(var_types), + required_inputs_fw_grad=tuple(required_inputs_fw_grad), + required_inputs_primal=tuple(required_inputs_primal), + required_original_self_value=fw_derivative.required_original_self_value, + is_reusing_outplace_formula=fw_derivative.is_reusing_outplace_formula, + ) + ) + + return ( + DifferentiabilityInfo( + name=foreach_function.func.name.name.base, + func=foreach_function, + op=f"Foreach{ref_diff_info.op}{foreach_function.func.name.overload_name}", + derivatives=modified_derivative_formulas, + forward_derivatives=forward_derivatives, + all_saved_inputs=tuple(set(all_saved_inputs)), + all_saved_outputs=tuple(set(all_saved_outputs)), + available_named_gradients=(), + used_named_gradients=set(), + args_with_derivatives=args_with_derivatives, + non_differentiable_arg_names=[], + output_differentiability=None, + output_differentiability_conditions=None, + ), + True, + ) + + +def match_differentiability_info( + native_functions: list[NativeFunction], + differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]], +) -> list[NativeFunctionWithDifferentiabilityInfo]: + """Sets the "derivative" key on declarations to matching autograd function + In-place functions will use the out-of-place derivative definition if there + is no in-place specific derivative. + """ + + functional_info_by_signature = { + schema.signature(strip_default=True): info_dict + for schema, info_dict in differentiability_infos.items() + if schema.kind() == SchemaKind.functional + } + non_functional_info_by_signature = { + schema.signature(strip_default=True): info_dict + for schema, info_dict in differentiability_infos.items() + if schema.kind() != SchemaKind.functional + } + + def find_info( + f: NativeFunction, + ) -> tuple[dict[str, DifferentiabilityInfo] | None, bool]: + # Don't bother matching info to generated out= variants + if "generated" in f.tags and f.func.kind() == SchemaKind.out: + return None, False + + # (1) Check for an exact match + if f.func in differentiability_infos: + return differentiability_infos[f.func], True + + # (2) If no exact match, check if the out-of-place variant + # of this operator has a match. + # i.e mul() for mul_() or mul_out() + # note(crcrpar): Check foreach or not because in-place foreach functions use backward defined for the existing + # native functions instead of the out-place counterparts. + f_sig = f.func.signature(strip_default=True) + if f_sig in functional_info_by_signature and not is_foreach_func(f): + return functional_info_by_signature[f_sig], False + + # (3) Some operators have a derivative explicitly defined for the mutable + # variant, but get a code-generated out-of-place variant which does *not* + # come with a derivative formula. + # For the generated out-of-place variant, use the mutable variant's formula + # if it exists. + if "generated" in f.tags and f_sig in non_functional_info_by_signature: + info_dict = non_functional_info_by_signature[f_sig] + # See https://github.com/pytorch/pytorch/pull/76320/files#r874816389 + assert not any( + any("self" in str(input.nctype.name) for input in info.all_saved_inputs) + for info in info_dict.values() + ), f"""\ +Attempted to convert a derivative formula for a mutable operator + to be used by automatically by its functional variant ("{str(f.func)}"). + this is not currently supported (we'd need to fix up the formula in the codegen).""" + return info_dict, False + + # (4) Generate derivative information of foreach functions if none is defined in `derivatives.yaml` + if is_foreach_func(f): + assert f.func not in differentiability_infos + diff_info, is_generated = gen_foreach_derivativeinfo( + f, + functional_info_by_signature, + non_functional_info_by_signature, + ) + if diff_info is None: + return None, False + # TODO(crcrpar): Avoid hard coding "Default" ideally. + diff_info_dict = {"Default": diff_info} + if is_generated: + differentiability_infos[f.func] = diff_info_dict + functional_info_by_signature[f.func] = diff_info_dict + return diff_info_dict, is_generated + + return None, False + + result: list[NativeFunctionWithDifferentiabilityInfo] = [] + for f in native_functions: + info_dict, is_exact_match = find_info(f) + + # Currently, the '.strides()' to 'strides_or_error' replacement does not support + # 'self' derivatives of an inplace function, so we must check for this case. + if f.func.kind() == SchemaKind.inplace and (info_dict is not None): + for info in info_dict.values(): + for derivative in info.derivatives: + if "self" in derivative.var_names: + for saved_input in derivative.saved_inputs: + assert "strides_or_error" not in saved_input.expr, ( + "Calling '.strides()' in the 'self' derivative formula of an " + f"in-place function is not supported: {f.func}" + ) + + if not info_dict: + result.append( + NativeFunctionWithDifferentiabilityInfo( + func=f, info=None, fw_derivatives=None + ) + ) + continue + + fw_derivative_dict: dict[str, Sequence[ForwardDerivative]] = {} + for key, info in info_dict.items(): + if not info.forward_derivatives: + fw_derivative_dict[key] = [] + continue + + forward_derivatives = info.forward_derivatives + + # For functions that have a single def for out-of-place and inplace (like abs()) + if f.func.kind() == SchemaKind.inplace: + # For inplace functions there is a little bit of work to do: + # 1) Validate the formula and make sure the input that is modified in not used: + # - If there is a formula for the inplace variant of the function (is_exact_match == True) then + # we make sure that the original value of the input that is being modified inplace (self_p) is + # not used in the formula. Note that the formula can use "original_self_p" here and that would + # trigger a clone of the original input. + # - If we are reusing the out of place formula (is_exact_match == False) then we replace every + # occurrence of self_p and self_t by original_self_p and original_self_t. These will be + # populated by cloned version of the original input (either the clone done by the backward AD + # logic if self is also used in a backward formula or a special clone that we add). + # 2) At this point, there cannot be a self_p in the formula. + # 3) Change "result" into "self_p" as by design, in the inplace function codegen, the result is + # simply called self (as it is modified inplace). + # 4) Update the required primals data in case it used to contain "result" but should now contain + # "self" + # 5) If it is not an exact match, the user formula is not modifying the existing forward grad + # inplace as it should. So add some code that makes sure that we do so if the forward grad + # already exists. + + assert ( + len(info.forward_derivatives) == 1 + ) # Only single output inplace should exist + fw_info = info.forward_derivatives[0] + formula = fw_info.formula + + def replace_self_with_original_self(formula: str, postfix: str) -> str: + def repl(m: re.Match[str]) -> str: + return f"{m.group(1)}original_self{postfix}{m.group(2)}" + + return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula) + + if re.search(IDENT_REGEX.format("self_p"), formula): + if is_exact_match: + # For manually defined formulas, don't allow the original value to be used + raise RuntimeError( + f'The formula for "{f.func.name}" is using the original value of self ' + "that is being modified inplace. This would lead to wrong forward gradients. " + 'Please use "result" in the formula only.' + ) + else: + # When the original formula is out of place, we save a clone of the primal + # value to be able to access this value if needed + # replace "self_p"/"self_t" from the formula by "original_self_p"/"original_self_t" + formula = replace_self_with_original_self(formula, "_p") + formula = replace_self_with_original_self(formula, "_t") + + # replace "result" from the formula by "self_p" + def repl(m: re.Match[str]) -> str: + return f"{m.group(1)}self_p{m.group(2)}" + + formula = re.sub(IDENT_REGEX.format("result"), repl, formula) + + required_primals = fw_info.required_inputs_primal + if re.search(IDENT_REGEX.format("self_p"), formula): + required_primals = ( + required_primals + ("self",) if required_primals else ("self",) + ) + + if not is_exact_match: + # NOTE [In-place forward AD formula Optimization] + # + # This optimization transforms the formula to directly do inplace, i.e. + # instead of self_t.copy_(self_t.op()) we do self_t.op_() when the following are met: + # + # 1) the formula satisfies the pattern: "self_t.op(*args)" + # 2) "op" in (1) needs to be the same as the op the derivative is for + # + # (2) may seem too strict, but currently the only ops that satisfy (1) also satisfy (2) + # If there is a need, we can relax (2) to allow any op that has an in-place variant + is_single_method_on_self_t = False + directly_do_inplace = False + op_name: str | None = None + between_parens: str | None = None + match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula) + if match: + op_name, between_parens = match.group(1), match.group(2) + + # We want to... + # Match: self_t.op1(other_p.op2(arg)) + # Avoid: self_t.op1(args) + self_t.op2(args) + # Avoid: self_t.op1(other_p.op2(arg)) + self_t.op2(args) + def check_parens_nest_level_gt_zero(s: str) -> bool: + level = 1 + for ch in s: + if ch == ")": + level -= 1 + if level == 0: + return False + if ch == "(": + level += 1 + return True + + is_single_method_on_self_t = check_parens_nest_level_gt_zero( + between_parens + ) + directly_do_inplace = ( + is_single_method_on_self_t and op_name == info.name + ) + + if directly_do_inplace: + assert op_name is not None + assert between_parens is not None + formula = f"self_t_raw.defined() ? self_t_raw.{op_name}_({between_parens}) : {formula}" + else: + # Make sure that the forward grad is modified inplace when the original formula + # is out of place + formula = f"self_t_raw.defined() ? self_t_raw.copy_({formula}) : {formula}" + + required_original_self_value = bool( + re.search(IDENT_REGEX.format("original_self_p"), formula) + ) or bool(re.search(IDENT_REGEX.format("original_self_t"), formula)) + + forward_derivatives = [ + ForwardDerivative( + formula=formula, + var_names=("self",), + var_types=fw_info.var_types, + required_inputs_fw_grad=fw_info.required_inputs_fw_grad, + required_inputs_primal=required_primals, + required_original_self_value=required_original_self_value, + is_reusing_outplace_formula=not is_exact_match, + ), + ] + + fw_derivative_dict[key] = forward_derivatives + + result.append( + NativeFunctionWithDifferentiabilityInfo( + func=f, info=info_dict, fw_derivatives=fw_derivative_dict + ) + ) + + return result + + +def is_differentiable( + name: str, type: Type, info: DifferentiabilityInfo | None +) -> bool: + return type.is_tensor_like() and ( + info is None or name not in info.non_differentiable_arg_names + ) + + +def gen_differentiable_outputs( + fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default" +) -> list[DifferentiableOutput]: + f = fn.func + info = fn.info[key] if fn.info else None + outputs: list[DifferentiableOutput] = [ + DifferentiableOutput( + name=name, + type=ret.type, + cpp_type=cpp.return_type(ret, symint=True).cpp_type(), + ) + for name, ret in zip(cpp.return_names(f), f.func.returns) + ] + output_differentiability = info.output_differentiability if info else None + if output_differentiability is not None: + if len(output_differentiability) != len(outputs): + raise RuntimeError( + f"The length of output_differentiability ({len(output_differentiability)}), " + f"does not match the number of outputs ({len(outputs)})." + ) + differentiable_outputs: list[DifferentiableOutput] = [] + if False in output_differentiability and f.func.kind() == SchemaKind.inplace: + raise RuntimeError( + "output_differentiability=False for inplace operation (version_counter won't get updated)" + ) + for differentiable, output in zip(output_differentiability, outputs): + if differentiable: + differentiable_outputs.append(output) + return differentiable_outputs + candidate_differentiable_outputs = list( + filter(lambda r: is_differentiable(r.name, r.type, info), outputs) + ) + if uses_single_grad(info): + return candidate_differentiable_outputs[:1] + else: + return candidate_differentiable_outputs diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/meta.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/meta.py new file mode 100644 index 0000000000000000000000000000000000000000..2e99d151faeaccea7ca47f372fd26f9985ce7249 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/meta.py @@ -0,0 +1,13 @@ +from torchgen.model import NativeFunctionsGroup + + +# Follows dispatcher calling convention, but: +# - Mutable arguments not allowed. Meta functions are always +# written in functional form. Look at FunctionSchema.signature() +# - No tensor returns; instead we return a TensorMeta describing +# the tensor in question + + +def name(g: NativeFunctionsGroup) -> str: + # use the overload name from the functional version + return str(g.functional.func.name).replace(".", "_") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/python.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/python.py new file mode 100644 index 0000000000000000000000000000000000000000..dbfa73060163057e979d231c06f63bb29ea87daa --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/python.py @@ -0,0 +1,1548 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from torchgen.api import cpp +from torchgen.api.types import Binding, CppSignature, CppSignatureGroup +from torchgen.gen import pythonify_default +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + FunctionSchema, + ListType, + NativeFunction, + OptionalType, + Return, + Type, + Variant, +) + + +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Data Models +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# [Notes] python binding codegen +# +# The Python binding codegen produces code that takes the input list of +# PyObjects, finds the matching ATen C++ function using PythonArgParser, +# converts the PyObjects into C++ types and calls the ATen C++ function: +# +# +--------+ parsing +------------------------+ binding +-----------------------+ +# | PyObjs | ---------> | PythonArgParser Output | ---------> | Cpp Function Dispatch | +# +--------+ +------------------------+ +-----------------------+ +# +# The following examples demonstrate the data models the Python binding +# codegen needs to deal with and the tasks it needs to accomplish. It +# helps understand the purpose of the new data types we introduced below. +# +# - Function Schema (source of truth) +# +# aten::empty.names(int[] size, *, Dimname[]? names, +# ScalarType? dtype=None, Layout? layout=None, +# Device? device=None, bool? pin_memory=None, +# MemoryFormat? memory_format=None) -> Tensor +# +# - Python Signature +# +# It's used to generate input schema string for PythonArgParser. +# Note: TensorOptions fields are reordered and the additional +# 'requires_grad' field is added: +# +# empty(IntArrayRef size, *, DimnameList? names, +# MemoryFormat? memory_format=None, ScalarType dtype=None, +# Layout layout=torch.strided, Device device=None, +# bool pin_memory=False, bool requires_grad=False) +# +# - C++ Signature +# +# It's used to generate C++ lambda formals & dispatch call. +# Note: the scattered TensorOptions fields are packed into 'options'. +# +# auto dispatch_empty = +# [](IntArrayRef size, std::optional names, +# const TensorOptions & options, +# std::optional memory_format) -> Tensor { +# pybind11::gil_scoped_release no_gil; +# return torch::empty(size, names, options, memory_format); +# }; +# +# - Binding between Python Arguments and C++ Arguments +# +# Given a set of Python Arguments in scope, we need produce the +# binding expressions that translate the Python API into C++ API: +# +# Python Args Cpp Args Binding Exprs +# ----------------------------------------------------------------- +# 0: size size '_r.intlist(0)' +# 1: names names 'names' [special init] +# 2: memory_format -------+ +# 3: dtype -----+-|--> options 'options' [special packing] +# 4: layout / | +# 5: device / +--> memory_format '_r.memoryformatOptional(2)' +# 6: pin_memory / +# 7: requires_grad -+ +# +# So the full dispatch expression would look like: +# +# dispatch_empty(_r.intlist(0), names, options, +# _r.memoryformatOptional(2)) +# +# Where does 'names' come from? It involves special local init: +# +# auto __names = _r.toDimnameListOptional(1); +# std::optional names = +# __names ? std::make_optional(DimnameList(__names.value())) +# : std::nullopt; +# +# Where does 'options' come from? It involves special local init +# for TensorOptions. Note that Python side has the additional +# 'requires_grad' field: +# +# const auto options = TensorOptions() +# .dtype(_r.scalartype(3)) +# .device(_r.device(5)) +# .layout(_r.layoutOptional(4)) +# .requires_grad(_r.toBool(7)) +# .pinned_memory(_r.toBool(6)); +# +# In some other cases one Python Argument can map to multiple C++ +# Arguments. For example: +# +# aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False) +# -> (Tensor values, Tensor indices) +# +# Python Args Cpp Args Binding Exprs +# --------------------------------------------------------------------- +# +----> max 'out[0]' +# /-----> max_values 'out[1] +# 0: input / self '_r.tensor(0)' +# 1: dim / dim '_r.dimname(1)' +# 2: keepdim / keepdim '_r.toBool(2)' +# 3: out -----+ [local init] out '_r.tensorlist_n<2>(3)' +# +# As demonstrated above, the binding can involve reordering, +# packing, unpacking and special local inits. +# +# +# Let's look at a concrete example: +# +# static PythonArgParser parser({ +# "abs(Tensor input, *, Tensor out=None)", +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ^ +# +--- Python Schema, represented by PythonSignature and PythonArgument +# +# }, /*traceable=*/true); +# +# ParsedArgs<2> parsed_args; +# auto _r = parser.parse(nullptr, args, kwargs, parsed_args); +# +# ... +# +# if (_r.isNone(1)) { +# ~~~~~~~~~~~~ <--- Scattered PythonArgParser output (arg name = 'out') +# represented by PythonArgParserOutputExpr +# +# // aten::abs(Tensor self) -> Tensor +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ^ +# +--- NativeFunction schema, base version +# +# auto dispatch_abs = [](const Tensor & self) -> Tensor { +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ^ +# +--- dispatch_lambda_args / dispatch_lambda_return_str +# generated from NativeFunction / CppSignature +# (deprecated PythonSignature is special) +# arguments are represented by DispatchLambdaArgument +# +# pybind11::gil_scoped_release no_gil; +# return self.abs(); +# ~~~~~~~~~~~ <--- cpp_dispatch_target / cpp_dispatch_exprs +# generated from NativeFunction / CppSignature +# }; +# return wrap(dispatch_abs(_r.tensor(0))); +# ~~~~~~~~~~~~~ +# ^ +# +--- dispatch_lambda_exprs +# binding PythonArgParserOutputExpr (python args) +# and DispatchLambdaArgument (c++ args) +# +# } else { +# // aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# ^ +# +--- NativeFunction schema, out-variant +# +# auto dispatch_abs_out = [](Tensor out, const Tensor & self) -> Tensor { +# pybind11::gil_scoped_release no_gil; +# return at::abs_out(out, self); +# }; +# return wrap(dispatch_abs_out(_r.tensor(1), _r.tensor(0))); +# } +# +# +# [Notes] python interface codegen +# The python dataclasses below are used used to generate both python binding code +# and pyi type hint signatures. +# In theory these two should look very similar, but there are number of differences +# in how pyi signatures vs. python_arg_parser signatures are generated. +# These differences have been encapsulated in signature_str() vs. signature_str_pyi() +# to display the full signatures, and argument_str() vs argument_str_pyi() to display arguments. +# For examples, only pyi signatures include return types. + + +def format_function_signature( + name: str, arguments: Iterable[str] = (), return_type: str | None = None +) -> str: + if not isinstance(arguments, (list, tuple)): + arguments = tuple(arguments) + return_type = f" -> {return_type}" if return_type is not None else "" + + sig = f"def {name}({', '.join(arguments)}){return_type}: ..." + if len(sig) <= 80 or len(arguments) == 0 or tuple(arguments) == ("self",): + return sig + + lines = [ + f"def {name}(", + *(f" {arg}," for arg in arguments), + f"){return_type}: ...", + ] + sig = "\n".join(lines) + if all(len(line) <= 80 for line in lines): + return sig + # ruff format bug for compound statements: https://github.com/astral-sh/ruff/issues/18658 + # use `skip` instead of `on` + `off` + return sig.removesuffix(" ...") + " # fmt: skip\n ..." + + +@dataclass(frozen=True) +class PythonReturns: + returns: tuple[Return, ...] + + +@dataclass(frozen=True) +class PythonArgument: + name: str + type: Type + default: str | None + + # Used to generate the default init expr for some PythonArgParser outputs, e.g.: + # + # _r.layoutWithDefault(3, layout_from_backend(self.options().backend()))) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # ^ + # +--- default_init str + default_init: str | None + + # Compute argument formal for python argument parsing. + # Needs to be consistent with torch/csrc/utils/python_arg_parser.h. + def argument_str(self, *, method: bool = False, symint: bool = True) -> str: + type_str = ( + argument_type_str(self.type, symint=symint) + .replace("const ", "") + .replace(" &", "") + ) + + name = self.name + # s/self/input/ outside method bindings + # [old codegen] TODO: remove this? doesn't rename in codegen, it's just + # for the parse string + if name == "self" and type_str in ["Tensor", "Number"] and not method: + name = "input" + + # add default + if self.default is not None: + default = { + "nullptr": "None", + "::std::nullopt": "None", + "std::nullopt": "None", + "{}": "None", + }.get(self.default, self.default) + return f"{type_str} {name}={default}" + else: + return f"{type_str} {name}" + + def argument_str_pyi( + self, *, method: bool = False, deprecated: bool = False + ) -> str: + type_str = argument_type_str_pyi(self.type) + + name = self.name + # s/self/input/ outside method bindings + # [old codegen] TODO: remove this? doesn't rename in codegen, it's just + # for the parse string + if name == "self" and type_str == "Tensor" and not method and not deprecated: + name = "input" + + if name == "from": # from is a Python keyword... + name += "_" + + # pyi merges the _out and functional variants into the same signature, with an optional out arg + if name == "out" and type_str == "Tensor" and not deprecated: + type_str = f"{type_str} | None".replace(" | None | None", " | None") + + # pyi deprecated signatures don't get defaults for their out arg + treat_as_no_default = ( + deprecated + and isinstance(self, PythonOutArgument) + and self.default == "None" + ) + + # add default + if self.default is not None and not treat_as_no_default: + if ( + isinstance(self.type, ListType) + and self.type.elem == BaseType(BaseTy.int) + and self.default.startswith("{") + and self.default.endswith("}") + ): + default = ( + "(" + ", ".join(map(str.strip, self.default[1:-1].split(","))) + ")" + ) + else: + default = { + "nullptr": "None", + "::std::nullopt": "None", + "std::nullopt": "None", + "{}": "None", + "c10::MemoryFormat::Contiguous": "contiguous_format", + "QScheme::PER_TENSOR_AFFINE": "per_tensor_affine", + }.get(self.default, self.default) + return f"{name}: {type_str} = {default}" + else: + return f"{name}: {type_str}" + + +@dataclass(frozen=True) +class PythonOutArgument(PythonArgument): + # In Python signature multiple output fields are packed into one 'out' argument. + # When binding to C++, it's first binded to a local 'out' variable: + # 'auto out = _r.tensorlist_n<2>(2);', + # then binded to scattered C++ output arguments as 'out[0]', 'out[1]', and etc. + # TODO: maybe don't need keep scattered out fields for python signature? + outputs: tuple[PythonArgument, ...] + + @staticmethod + def from_outputs(outputs: tuple[PythonArgument, ...]) -> PythonOutArgument | None: + if not outputs: + return None + + size = len(outputs) + if size == 1: + return PythonOutArgument( + name=outputs[0].name, + type=outputs[0].type, + default="None", + default_init=None, + outputs=outputs, + ) + elif size > 1: + if any(not a.type.is_tensor_like() for a in outputs): + raise RuntimeError(f"Unsupported output type: {outputs}") + return PythonOutArgument( + name="out", + # TODO: shouldn't this be OptionalType[ListType[...]], since it defaults to None? + type=ListType(BaseType(BaseTy.Tensor), size), + default="None", + default_init=None, + outputs=outputs, + ) + raise AssertionError(r"Unexpected PythonOutArgument size") + + +@dataclass(frozen=True) +class PythonSignature: + # Base operator name, without inplace/outplace suffix. + name: str + + # Positional arguments. + # TODO: create a dedicated SelfArgument type for 'self'? + input_args: tuple[PythonArgument, ...] + + # Keyword arguments excluding the 'out' argument and scattered kwargs belonging + # to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc). + input_kwargs: tuple[PythonArgument, ...] + + output_args: PythonOutArgument | None + + # Return types, which are only used by pyi + returns: PythonReturns + + # These are scattered kwargs arguments belonging to TensorOptions. + # When binding to C++, they are packed into a TensorOptions object 'options'. + # It's possible that the C++ signature doesn't take TensorOptions object (e.g. + # for out variant), in which case they will be used as scattered fields without + # being packed into 'options'. + # TODO: maybe create a PythonTensorOptionsArgument? + tensor_options_args: tuple[PythonArgument, ...] + + # method or function signature? + method: bool + + @property + def deprecated(self) -> bool: + return False + + def arguments( + self, *, skip_outputs: bool = False, skip_tensor_options: bool = False + ) -> tuple[PythonArgument | PythonOutArgument, ...]: + result: list[PythonArgument | PythonOutArgument] = [] + result.extend(self.input_args) + result.extend(self.input_kwargs) + if self.output_args is not None and not skip_outputs: + result.append(self.output_args) + if not skip_tensor_options: + result.extend(self.tensor_options_args) + return tuple(result) + + def arguments_count(self) -> int: + return len(self.arguments()) + + def output_idx(self) -> int: + return len(self.input_args) + len(self.input_kwargs) + + # [old codegen] Compute the Python function signature for argument parsing, + # as specified in torch/csrc/utils/python_arg_parser.h. WARNING: + # this is NOT the same type signature as specified by PEP 484 + # as understood by mypy; our format was independently developed + # and has some quirks to make it more suitable specifically + # for error parsing. + # + # For a translation to mypy-valid type signatures, see + # signature_str_pyi(). + def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str: + args = self.arguments(skip_outputs=skip_outputs) + schema_formals: list[str] = [ + a.argument_str(method=self.method, symint=symint) for a in args + ] + positional_argc = len(self.input_args) + if len(schema_formals) > positional_argc: + schema_formals.insert(positional_argc, "*") + + return f"{self.name}({', '.join(schema_formals)})" + + def signature_str_pyi(self, *, skip_outputs: bool = False) -> str: + args = self.arguments(skip_outputs=skip_outputs) + schema_formals: list[str] = [ + a.argument_str_pyi(method=self.method) for a in args + ] + positional_argc = len(self.input_args) + if len(schema_formals) > positional_argc: + schema_formals.insert(positional_argc, "*") + + # only pyi signatures include returns + returns_str = returns_str_pyi(self) + # pyi also includes self (with no typing/defaults) for methods + if self.method: + schema_formals.insert(0, "self") + return format_function_signature(self.name, schema_formals, returns_str) + + def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None: + # only pyi uses vararg signatures + args = self.arguments(skip_outputs=skip_outputs) + schema_formals: list[str] = [ + a.argument_str_pyi(method=self.method) for a in args + ] + # vararg only applies to pyi signatures. vararg variants are not generated for all signatures + num_args = self.arguments_count() + if num_args == 0: + return None + + num_positionalargs = len(self.input_args) + + vararg_type = args[0].type + if not ( + isinstance(vararg_type, ListType) + and str(vararg_type.elem) in ["int", "SymInt"] + and num_positionalargs == 1 + ): + return None + + # Below are the major changes in vararg vs. regular pyi signatures + # vararg signatures also omit the asterix + assert isinstance(vararg_type, ListType) + schema_formals[0] = ( + "*" + args[0].name + ": " + argument_type_str_pyi(vararg_type.elem) + ) + + returns_str = returns_str_pyi(self) + # pyi also includes self (with no typing/defaults) for methods + if self.method: + schema_formals.insert(0, "self") + return format_function_signature(self.name, schema_formals, returns_str) + + +# The deprecated python signature involves some special logic, so create a +# dedicated data model to store these extra properties. +@dataclass(frozen=True) +class PythonSignatureDeprecated(PythonSignature): + # Schema for the deprecated function + deprecated_schema: FunctionSchema + + # The deprecated signature might miss some arguments that the corresponding + # C++ signature expects. We need store the constant default values to pass in. + # For example: + # [deprecate signature]: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2) + # [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + # [func call]: self.addmm(mat1, mat2, beta, 1) + # We store ['self', 'mat1', 'mat2', 'beta', '1'] in this case. + deprecated_args_exprs: tuple[str, ...] + + @property + def deprecated(self) -> bool: + return True + + def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str: + return ( + PythonSignature.signature_str( + self, skip_outputs=skip_outputs, symint=symint + ) + + "|deprecated" + ) + + def signature_str_pyi(self, *, skip_outputs: bool = False) -> str: + args = self.arguments(skip_outputs=skip_outputs) + schema_formals: list[str] = [ + a.argument_str_pyi(method=self.method, deprecated=True) for a in args + ] + positional_argc = len(self.input_args) + if len(schema_formals) > positional_argc: + schema_formals.insert(positional_argc, "*") + + returns_str = returns_str_pyi(self) + return format_function_signature(self.name, schema_formals, returns_str) + + def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None: + # the codegen doesn't include vararg variants for deprecated signatures + return None + + +# This struct is used to hold the PythonSignature and its corresponding +# NativeFunction BEFORE grouping base and out-variant functions. +# Why not store NativeFunction in PythonSignature or construct PythonSignature +# from NativeFunction? Because they are not 1-1 mapped. +# One native function could have both deprecated and non-deprecated python +# signatures - NativeFunction doesn't contain information to construct the +# deprecated python signature. +# One python signature is used to handle both the base and the out-variant +# function - see 'PythonSignatureGroup'. +@dataclass(frozen=True) +class PythonSignatureNativeFunctionPair: + signature: PythonSignature + function: NativeFunction + + +# We merge pairs of functions with signatures that are equivalent mod +# output arguments, and use a single entry in the python_arg_parser sig +# list for both (output arguments become optional). +@dataclass(frozen=True) +class PythonSignatureGroup: + # The signature used for Python argument parsing. The outplace signature + # is preferred if exists, because it can be used to parse inputs for both + # the out-place variant and the base version (with output omitted). + signature: PythonSignature + + # The regular ATen declaration (e.g. conv2d) + base: NativeFunction + + # The out variant (e.g. conv2d_out) + outplace: NativeFunction | None + + @classmethod + def from_pairs( + cls, + functional: PythonSignatureNativeFunctionPair, + out: PythonSignatureNativeFunctionPair | None, + ) -> PythonSignatureGroup: + if out is None: + return PythonSignatureGroup( + signature=functional.signature, + base=functional.function, + outplace=None, + ) + + # prefer the signature with optional out=... arguments because it's the + # superset that can be used to parse input for both base and outplace. + signature_kwargs = out.signature.__dict__.copy() + + # Out overloads in C++ don't have TensorOptions arguments, + # so take these from the functional variant + signature_kwargs["tensor_options_args"] = ( + functional.signature.tensor_options_args + ) + + return PythonSignatureGroup( + signature=type(out.signature)(**signature_kwargs), + base=functional.function, + outplace=out.function, + ) + + +# C++ function dispatch is wrapped in a lambda function. The lambda function +# has almost the same signature as the C++ function, only with some small +# variants - see details below. +# This data model is used to represent arguments of the lambda function +# signature. +@dataclass(frozen=True) +class DispatchLambdaArgument: + name: str + type_str: str + is_out_arg: bool + + +# To pass PyObjects arguments to C++ function (via the lambda wrapper), +# we need first convert PyObjects into simple C++ objects. This work +# is done by PythonArgParser. +# This data model is used to represent the output of PythonArgParser. +# It has 1-1 mapping with PythonArgument in PythonSignature. +@dataclass(frozen=True) +class PythonArgParserOutputExpr: + # argument name + name: str + + # RHS expression to reference PythonArgParser output. + expr: str + + # In some special cases we need create different expr, e.g.: + # '_r.isNone(1)' instead of '_r.tensor(1)'. + index: int + + # The python argument it maps to. + argument: PythonArgument + + @property + def is_none_expr(self) -> str: + return f"_r.isNone({self.index})" + + +# To pass PythonArgParser output to the lambda wrapper, we need bind +# PythonArgParserOutputExpr to DispatchLambdaArgument. +# They are not always 1-1 mapped, e.g. scattered TensorOptions fields +# need be packed into a TensorOptions object, which is the argument +# that the lambda function wrapper takes. +@dataclass(frozen=True) +class DispatchLambdaArgumentExprs: + # The exprs that provide the binding for lambda arguments, e.g.: + # + # 'self' -> '_r.tensor(0)' + # 'min' -> 'out[0]' / 'min_indices' -> 'out[1]' + # 'options' -> 'options' + # + # It has 1-1 mapping with DispatchLambdaArgument. + exprs: Sequence[str] + + # Special local inits, which might introduce new variables that + # the 'exprs' above reference, e.g.: + # + # 'auto out = _r.tensorlist_n<2>(2);' + # + inits: Sequence[str] + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Helper Functions +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def _cpp_signature(f: NativeFunction, *, method: bool = False) -> CppSignature: + return CppSignatureGroup.from_native_function(f, method=method).signature + + +def has_tensor_options(f: NativeFunction) -> bool: + return f.func.arguments.tensor_options is not None + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Python Signature +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +# 'simple_type' was introduced by the old codegen, which is slightly +# different from the python schema type, e.g.: doesn't have '?' suffix +# for optional Tensor/TensorList; doesn't have '[size]' suffix for list type. +def argument_type_str( + t: Type, *, simple_type: bool = False, symint: bool = True +) -> str: + if isinstance(t, BaseType): + if t.name == BaseTy.int: + return "int64_t" + elif t.name == BaseTy.float: + return "double" + elif t.name == BaseTy.str: + return "c10::string_view" + elif t.name in [ + BaseTy.Tensor, + BaseTy.bool, + BaseTy.QScheme, + BaseTy.Scalar, + BaseTy.ScalarType, + BaseTy.Generator, + BaseTy.Storage, + BaseTy.Layout, + BaseTy.Device, + BaseTy.DeviceIndex, + BaseTy.MemoryFormat, + BaseTy.Dimname, + BaseTy.Stream, + BaseTy.SymInt, + ]: + # These python schema type names line up with their function schema names + return t.name.name + + elif isinstance(t, OptionalType): + elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint) + return f"{elem}?" + elif isinstance(t, ListType): + size = t.size if not simple_type else None + if str(t.elem) == "bool": + assert t.size is not None + return f"::std::array" + elif str(t.elem) == "int": + return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef" + elif str(t.elem) == "SymInt": + if symint: + return ( + f"SymIntArrayRef[{size}]" if size is not None else "SymIntArrayRef" + ) + else: + return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef" + elif str(t.elem) == "Tensor": + return f"TensorList[{size}]" if size is not None else "TensorList" + elif str(t.elem) == "Scalar": + return f"ScalarList[{size}]" if size is not None else "ScalarList" + elif str(t.elem) == "Tensor?": + if simple_type: + return "c10::List<::std::optional>" + else: + return "const c10::List<::std::optional> &" + elif str(t.elem) == "Dimname": + return f"DimnameList[{size}]" if size is not None else "DimnameList" + elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint) + return f"ArrayRef<{elem}>" + + raise RuntimeError(f"unrecognized type {repr(t)}") + + +def argument_type_size(t: Type) -> int | None: + l = t.is_list_like() + if l is not None and str(l.elem) != "bool": + return l.size + else: + return None + + +def argument(a: Argument) -> PythonArgument: + return PythonArgument( + name=a.name, + type=a.type, + # TODO: directly translate a.default to python default + default=( + str(pythonify_default(cpp.default_expr(a.default, a.type, symint=False))) + if a.default is not None + else None + ), + default_init=None, + ) + + +# Generates a PythonSignature that can be used for either .pyi or PythonArgParser codegen +def signature( + f: NativeFunction, *, method: bool = False, pyi: bool = False +) -> PythonSignature: + return signature_from_schema( + f.func, category_override=f.category_override, method=method, pyi=pyi + ) + + +def signature_from_schema( + func: FunctionSchema, + *, + category_override: str | None, + method: bool = False, + pyi: bool = False, +) -> PythonSignature: + args: list[Argument] = [] + args.extend(func.arguments.pre_self_positional) + # Skip SelfArgument if this is method. + if not method and func.arguments.self_arg is not None: + args.append(func.arguments.self_arg.argument) + args.extend(func.arguments.post_self_positional) + args.extend(func.arguments.pre_tensor_options_kwarg_only) + # Skip TensorOptionsArguments. Python side TensorOptions + # arguments are created based on different rules - see below. + args.extend(func.arguments.post_tensor_options_kwarg_only) + args.extend(func.arguments.out) + + input_arg_set = {a.name for a in func.arguments.flat_positional} + kwarg_only_set = {a.name for a in func.arguments.flat_kwarg_only} + out_arg_set = {a.name for a in func.arguments.out} + + input_args = tuple(map(argument, filter(lambda a: a.name in input_arg_set, args))) + input_kwargs = tuple( + map(argument, filter(lambda a: a.name in kwarg_only_set, args)) + ) + outputs = tuple(map(argument, filter(lambda a: a.name in out_arg_set, args))) + + # Reintroduce the scattered fields of TensorOptions for Python. + # Compared to the cpp counterpart, the python arguments have new property + # (default_init) and a new argument 'requires_grad', which require some + # special handlings. + # [old codegen] TODO: because these aren't guaranteed to be 100% faithful + # to the original versions in the yaml, this recreation is a potential + # source of drift between eager and JIT. Pull this logic out to a shared place. + + has_tensor_input_arg = any( + a.type.is_tensor_like() for a in func.arguments.flat_non_out + ) + if any(a.name == "requires_grad" for a in func.schema_order_arguments()): + raise ValueError( + "argument named requires_grad is reserved, should not explicitly add it in the schema" + ) + + # [old codegen] this probably won't work if one of the returns is not a tensor, + # but it will produce a compile-time error that is obvious. + has_tensor_return = any(r.type.is_tensor_like() for r in func.returns) + + name: str = cpp.name(func) + is_factory_function = category_override == "factory" or ( + has_tensor_return and not has_tensor_input_arg + ) + is_like_or_new_function = ( + category_override in ("new", "like") + or name.startswith("new_") + or name.endswith("_like") + ) + is_dummy_function = category_override == "dummy" + + tensor_options_args: list[PythonArgument] = [] + if (is_factory_function or is_like_or_new_function) and not is_dummy_function: + + def topt_default_init(name: str) -> str | None: + topt_args = func.arguments.tensor_options + if topt_args is None: + return None + a = getattr(topt_args, name) + if a.default is None or a.default == "None": + return None + return cpp.default_expr(a.default, a.type, symint=False) + + tensor_options_args.append( + PythonArgument( + name="dtype", + type=OptionalType(BaseType(BaseTy.ScalarType)), + default="None", + default_init=( + None if is_like_or_new_function else topt_default_init("dtype") + ), + ) + ) + tensor_options_args.append( + PythonArgument( + name="layout", + type=OptionalType(BaseType(BaseTy.Layout)), + default="None", + default_init=( + None if is_like_or_new_function else topt_default_init("layout") + ), + ) + ) + tensor_options_args.append( + PythonArgument( + name="device", + type=OptionalType(BaseType(BaseTy.Device)), + default="None", + default_init=( + None + if is_like_or_new_function + else ( + topt_default_init("device") + or "torch::tensors::get_default_device()" + ) + ), + ) + ) + tensor_options_args.append( + PythonArgument( + name="pin_memory", + type=OptionalType(BaseType(BaseTy.bool)), + default="False", + default_init=None, + ) + ) + tensor_options_args.append( + PythonArgument( + name="requires_grad", + type=OptionalType(BaseType(BaseTy.bool)), + default="False", + default_init=None, + ) + ) + + returns = PythonReturns(returns=func.returns) + + return PythonSignature( + name=str(func.name.name), + input_args=input_args, + input_kwargs=input_kwargs, + output_args=PythonOutArgument.from_outputs(outputs), + tensor_options_args=tuple(tensor_options_args), + returns=returns, + method=method, + ) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Python Interface +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def structseq_fieldnames(returns: tuple[Return, ...]) -> list[str]: + if len(returns) <= 1 or all(r.name is None for r in returns): + return [] + else: + if any(r.name is None for r in returns): + # When building on Windows, `PyStructSequence_UnnamedField` could not be + # resolved by the linker for some reason, which cause error in building: + # + # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol + # PyStructSequence_UnnamedField + # + # Thus, at this point in time, we do not support unnamed + # fields in structseq; you must either name all fields, + # or none of them. + raise ValueError("Unnamed field is not supported by codegen") + + return [str(r.name) for r in returns] + + +def argument_type_str_pyi(t: Type) -> str: + add_optional = False + if isinstance(t, OptionalType): + t = t.elem + add_optional = True + + ret = "" + if isinstance(t, BaseType): + if t.name in [BaseTy.int, BaseTy.DeviceIndex]: + ret = "_int" + if t.name == BaseTy.SymInt: + ret = "_int | SymInt" + elif t.name == BaseTy.float: + ret = "_float" + elif t.name == BaseTy.str: + ret = "str" + elif t.name == BaseTy.Scalar: + ret = "Number | _complex" + elif t.name == BaseTy.ScalarType: + ret = "_dtype" + elif t.name == BaseTy.bool: + ret = "_bool" + elif t.name == BaseTy.QScheme: + ret = "_qscheme" + elif t.name == BaseTy.Layout: + ret = "_layout" + elif t.name == BaseTy.Device: + ret = "DeviceLikeType | None" + elif t.name == BaseTy.MemoryFormat: + ret = "memory_format" + elif t.name == BaseTy.Dimname: + ret = "str | EllipsisType | None" + elif t.name == BaseTy.Storage: + ret = "Storage | UntypedStorage" + elif t.name in [BaseTy.Tensor, BaseTy.Generator, BaseTy.Stream]: + # These python schema type names line up with their function schema names + ret = t.name.name + + elif isinstance(t, ListType): + if str(t.elem) == "int": + ret = "_int | _size" if t.size is not None else "_size" + elif t.is_tensor_like(): + # TODO: this doesn't seem right... + # Tensor?[] currently translates to tuple[Tensor, ...] | list[Tensor] | None + # It should probably translate to tuple[Tensor | None, ...] | list[Tensor | None] + add_optional = True + ret = ( + "Tensor | tuple[Tensor, ...] | list[Tensor]" + if t.size is not None + else "tuple[Tensor, ...] | list[Tensor]" + ) + elif str(t.elem) == "float": + ret = "Sequence[_float]" + elif str(t.elem) == "SymInt" and t.size is not None: + elem = argument_type_str_pyi(t.elem) + ret = f"{elem} | Sequence[{elem}]" + else: + elem = argument_type_str_pyi(t.elem) + ret = f"Sequence[{elem}]" + + else: + raise RuntimeError(f"unrecognized type {repr(t)}") + + if add_optional: + ret = f"{ret} | None".replace(" | None | None", " | None") + + return ret + + +def return_type_str_pyi(t: Type) -> str: + # Where arguments are open to accepting Union, return types should return + # concrete types + + if isinstance(t, OptionalType): + inner = return_type_str_pyi(t.elem) + return f"{inner} | None".replace(" | None | None", " | None") + + if isinstance(t, BaseType): + if t.name == BaseTy.Device: + return "_device" + elif t.name == BaseTy.Dimname: + return "str | None" + else: + return argument_type_str_pyi(t) + + if isinstance(t, ListType): + inner = return_type_str_pyi(t.elem) + return f"tuple[{inner}, ...]" + + return argument_type_str_pyi(t) + + +def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None: + python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns] + structseq_name = signature.name + field_names = structseq_fieldnames(signature.returns.returns) + if field_names: + # These types are structseq objects which act like named NamedTuples, but + # the constructor acts like the constructor of tuple. Using typing.NamedTuple + # does not allow us to override __init__. + seq_type = f"tuple[{', '.join(python_returns)}]" + structseq_def_lines = [ + f"class {structseq_name}({seq_type}): # fmt: skip", + ] + for name, ret_type in zip(field_names, python_returns): + structseq_def_lines.extend( + [ + " @property", + f" def {name}(self) -> {ret_type}: ...", + ] + ) + structseq_def_lines.extend( + [ + " def __new__(", + " cls,", + f" sequence: {seq_type},", + " ) -> Self: # fmt: skip", + " ...", + f" n_fields: Final[_int] = {len(field_names)}", + f" n_sequence_fields: Final[_int] = {len(field_names)}", + " n_unnamed_fields: Final[_int] = 0", + " def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing", + "", # add an extra newline + ] + ) + structseq_def = "\n".join(structseq_def_lines) + # Example: + # structseq_def = ( + # "class max(tuple[Tensor, Tensor]): # fmt: skip\n" + # " @property\n" + # " def values(self) -> Tensor: ...\n" + # " @property\n" + # " def indices(self) -> Tensor: ...\n" + # " def __new__(\n" + # " cls,\n" + # " sequence: tuple[Tensor, Tensor],\n" + # " ) -> Self: # fmt: skip\n" + # " ...\n" + # " n_fields: Final[_int] = 2", + # " n_sequence_fields: Final[_int] = 2", + # " n_unnamed_fields: Final[_int] = 0", + # " def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing", + # ) + return structseq_name, structseq_def + return None + + +def returns_str_pyi(signature: PythonSignature) -> str: + field_names = structseq_fieldnames(signature.returns.returns) + if field_names: + return f"torch.return_types.{signature.name}" + + python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns] + if len(python_returns) > 1: + return "tuple[" + ", ".join(python_returns) + "]" + if len(python_returns) == 1: + return python_returns[0] + return "None" + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# C++ Function Dispatch +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# This section provides APIs to generate the code that does C++ function +# dispatch. The C++ function call is wrapped by a lambda function. +# For example: +# +# // aten::selu_(Tensor(a!) self) -> Tensor(a!) +# auto dispatch_selu_ = [](Tensor self) -> Tensor { +# pybind11::gil_scoped_release no_gil; +# return at::selu_(self); +# }; +# +# The lambda function's signature follows the C++ signature in common +# cases, e.g.: +# +# // aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor +# [](const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor +# +# For out variant the 'out' argument's type is changed from 'Tensor &' +# to 'Tensor'. It's because when calling the lambda it passes in the +# PythonArgParser output '_r.tensor(3)', which is stack allocated object +# and needs to pass by value. Also see comments in 'dispatch_lambda_return_str()'. +# +# // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) +# [](Tensor out, const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor +# +# For multi-output case it can keep using reference type because the +# PythonArgParser output has been unpacked to local variables, e.g.: +# +# // aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, +# // Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices) +# [](Tensor & max, Tensor & max_values, const Tensor & self, Dimname dim, bool keepdim) -> std::tuple +# +# For deprecated python signature, it should follow deprecated python arg order. +# TODO: This is to keep same byte-for-byte result as the old codegen - maybe unnecessary? + + +def dispatch_lambda_args( + ps: PythonSignature, f: NativeFunction, symint: bool = True +) -> tuple[DispatchLambdaArgument, ...]: + if isinstance(ps, PythonSignatureDeprecated): + schema = ps.deprecated_schema + else: + schema = f.func + + # Start with cpp arguments - dispatch lambda signature always include 'self' + cpp_args = cpp.arguments( + arguments=schema.arguments, + faithful=False, + symint=symint, + method=False, + cpp_no_default_args=f.cpp_no_default_args, + ) + out_args: set[str] = {a.name for a in schema.arguments.out} + + # Convert from cpp argument to lambda argument + def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument: + type_str = cpp_arg.type + is_out_arg = cpp_arg.name in out_args + if ps.method and cpp_arg.name == "self": + # For method's 'self', we can use 'const Tensor &' and simply ignore mutability! + type_str = "const at::Tensor &" + else: + # For other cases we need prevent dangling refs to temps (unless it's + # unpacked scattered output) + # The reason is explained in the comments above and in 'dispatch_lambda_return_str()'. + # TODO: avoid this special handling? + ensure_temp_safe = len(out_args) <= 1 or not is_out_arg + if ensure_temp_safe: + type_str = { + "at::Tensor &": "at::Tensor", + }.get(type_str, type_str) + return DispatchLambdaArgument( + name=cpp_arg.name, + type_str=type_str, + is_out_arg=is_out_arg, + ) + + return tuple(map(dispatch_lambda_arg, cpp_args)) + + +# [old codegen] XXX: if you got here because of an assertion failure, it doesn't mean +# it's enough to just extend the list here. Before you do this, make sure +# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h. +SUPPORTED_RETURN_TYPES = { + "at::Tensor", + "::std::tuple", + "::std::tuple", + "::std::tuple", + "::std::tuple", + "::std::tuple", + "::std::tuple", + "::std::tuple", + "::std::tuple", + "::std::tuple", + "::std::tuple", + "::std::tuple>", + "::std::vector", + # Needed for flash attention forw/backward + "::std::tuple", + "at::Scalar", + "bool", + "int64_t", + "void*", + "void", + "at::QScheme", + "double", + "at::IntArrayRef", + "at::ScalarType", + "at::Stream", +} + + +def dispatch_lambda_return_str(f: NativeFunction) -> str: + # [old codegen] Remove type annotation (e.g. 'Tensor' rather than 'Tensor &') + # because the dispatch lambdas take mutable arguments *by value*, not + # by reference. If you then return a reference to such an argument, you + # will now have a pointer to a dangling stack entry. Not good. + # + # You want: + # + # auto dispatch_selu_ = [](Tensor self) -> Tensor { ...; return at::selu_(self); }; + # ^^^^^^ + # + # *not* + # + # auto dispatch_selu_ = [](Tensor self) -> Tensor& { ...; return at::selu_(self); }; + # ^^^^^^^ + # + # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing + # codegen looks like dispatch_selu_(_r.tensor(0)), and you can't take a + # mutable reference to temporary. Maybe we could assign it to a + # variable itself.) + returns_without_annotation = tuple( + Return(r.name, r.type, None) for r in f.func.returns + ) + return_str = cpp.returns_type(returns_without_annotation, symint=True).cpp_type() + if return_str not in SUPPORTED_RETURN_TYPES: + raise RuntimeError(f"{f.func.name} returns unsupported type {return_str}") + return return_str + + +def cpp_dispatch_target(f: NativeFunction) -> str: + symint = f.func.has_symint() + name = cpp.name(f.func, symint_overload=symint) + if Variant.method in f.variants: + return f"self.{name}" + if Variant.function in f.variants: + if has_tensor_options(f) or f.func.name.name.base.endswith("_like"): + namespace = "torch" + else: + namespace = "at" + return f"{namespace}::{name}" + raise RuntimeError(f"could not dispatch, neither function nor method: {f.func}") + + +def cpp_dispatch_exprs( + f: NativeFunction, + *, + python_signature: PythonSignature | None = None, +) -> tuple[str, ...]: + cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments() + + exprs: tuple[str, ...] = () + if not isinstance(python_signature, PythonSignatureDeprecated): + # By default the exprs are consistent with the C++ signature. + exprs = tuple(a.name for a in cpp_args) + else: + # For deprecated python signature we may need fill in some constants. + exprs = tuple( + filter( + lambda n: n != "out" or f.func.is_out_fn(), + python_signature.deprecated_args_exprs, + ) + ) + + if Variant.method in f.variants: + exprs = tuple(filter("self".__ne__, exprs)) + + return exprs + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Python / C++ Args Binding +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +# We explicitly enumerate the PythonArgParser unpacking methods for all +# supported types. This might be more verbose than necessary, partially +# because of the irregularity of unpacking method naming, partially +# because we want to mimic the old codegen behavior - to reject +# unexpected and/or unsupported cases which the old codegen rejects. +# For certain cases it is intentionally more restrictive than necessary, +# e.g.: it doesn't accepts doublelist with definite size. +def arg_parser_unpack_method( + t: Type, default: str | None, default_init: str | None, *, symint: bool = True +) -> str: + has_default_init = default_init is not None + if has_default_init and str(t) not in ( + "ScalarType?", + "ScalarType", + "Device", + "Device?", + "Layout", + "Layout?", + "bool", + "bool?", + ): + raise RuntimeError(f"type '{t}' does not supported unpacking with default") + + if isinstance(t, BaseType): + if t.name in [ + BaseTy.Tensor, + BaseTy.Stream, + BaseTy.Storage, + BaseTy.Scalar, + BaseTy.Dimname, + ]: + # These unpack methods line up with their schema names + return t.name.name.lower() + elif t.name == BaseTy.ScalarType: + return "scalartypeWithDefault" if has_default_init else "scalartype" + elif t.name == BaseTy.Device: + return "deviceWithDefault" if has_default_init else "device" + elif t.name == BaseTy.DeviceIndex: + return "toInt64" + elif t.name == BaseTy.int: + return "toInt64" + elif t.name == BaseTy.SymInt: + return "toSymInt" if symint else "toInt64" + elif t.name == BaseTy.bool: + return "toBoolWithDefault" if has_default_init else "toBool" + elif t.name == BaseTy.float: + return "toDouble" + elif t.name == BaseTy.str: + return "stringView" + elif t.name == BaseTy.Layout: + return "layoutWithDefault" if has_default_init else "layout" + elif t.name == BaseTy.MemoryFormat: + return "memoryformat" + + elif isinstance(t, OptionalType): + if str(t.elem) == "Tensor": + return "optionalTensor" + elif str(t.elem) == "Generator": + return "generator" + elif str(t.elem) == "Dimname[]": + return "toDimnameListOptional" + elif not has_default_init and default in ( + None, + "None", + "::std::nullopt", + "std::nullopt", + ): + # If default is None: append 'Optional' to elem's unpacking method + return ( + arg_parser_unpack_method(t.elem, None, None, symint=symint) + "Optional" + ) + else: + # Otherwise, load as underlying type with default + return arg_parser_unpack_method( + t.elem, default, default_init, symint=symint + ) + + elif isinstance(t, ListType): + if str(t.elem) == "Tensor": + # accept and use definite size + return f"tensorlist_n<{t.size}>" if t.size is not None else "tensorlist" + elif str(t.elem) == "Tensor?": + return "list_of_optional_tensors" + elif str(t.elem) == "Dimname": + # accept definite size + return "dimnamelist" + elif str(t.elem) == "int": + # accept definite size + return "intlist" + elif str(t.elem) == "float": + return "doublelist" + elif str(t.elem) == "SymInt": + # accept definite size + return "symintlist" if symint else "intlist" + elif str(t.elem) == "Scalar": + return "scalarlist" + raise RuntimeError(f"type '{t}' is not supported by PythonArgParser") + + +# Return RHS expression for python argument using PythonArgParser output. +# e.g. for arg name 'foo', arg type 'bool', arg_index = 2, returns '_r.toBool(2)' +def arg_parser_output_expr( + arg_index: int, a: PythonArgument, *, symint: bool = True +) -> PythonArgParserOutputExpr: + has_default = a.default_init is not None + unpack_method = arg_parser_unpack_method( + t=a.type, default=a.default, default_init=a.default_init, symint=symint + ) + default = f", {a.default_init}" if has_default else "" + expr = f"_r.{unpack_method}({arg_index}{default})" + + return PythonArgParserOutputExpr( + name=a.name, + expr=expr, + index=arg_index, + argument=a, + ) + + +# Returns a map with key = arg_name and value = PythonArgParserOutputExpr. +def arg_parser_output_exprs( + ps: PythonSignature, f: NativeFunction, *, symint: bool = True +) -> dict[str, PythonArgParserOutputExpr]: + return { + e.name: e + for i, a in enumerate(ps.arguments()) + for e in (arg_parser_output_expr(i, a, symint=symint),) + } + + +# argument name to type for scattered tensor options fields +TENSOR_OPTIONS_FIELDS = { + "dtype": "ScalarType?", + "device": "Device?", + "layout": "Layout?", + "pin_memory": "bool?", + "requires_grad": "bool?", +} + + +# bind arg parser outputs (python args) with dispatch lambda arguments (c++ args). +def dispatch_lambda_exprs( + ps: PythonSignature, f: NativeFunction, *, symint: bool = True +) -> DispatchLambdaArgumentExprs: + # This method is to bind 'arg_parser_outputs' and 'lambda_args' by producing + # 'inits' and 'lambda_args_exprs' for each lambda argument using arg parser + # outputs. + arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint) + lambda_args = dispatch_lambda_args(ps, f, symint=symint) + inits: list[str] = [] + lambda_args_exprs: dict[str, str] = {} + + has_toptions = has_tensor_options(f) + + # 1. special inits/unpacking to provide binding exprs for lambda arguments. + for a in ps.arguments(skip_tensor_options=True): + name = a.name + arg_parser_expr = arg_parser_outputs[a.name].expr + + if has_toptions and name == "self": + # TODO: why this needs to be special case? + inits.extend( + [ + f"auto self = {arg_parser_expr};", + ] + ) + lambda_args_exprs[name] = name + elif ( + isinstance(a, PythonOutArgument) + and len(a.outputs) > 1 + and f.func.is_out_fn() + ): + inits.extend( + [ + f"auto out = {arg_parser_expr};", + ] + ) + for i, out_arg in enumerate(a.outputs): + lambda_args_exprs[out_arg.name] = f"out[{i}]" + elif str(a.type) == "Dimname[]?": + # [old codegen] + # TODO: make this part of something more general, or get rid of it. + # optional> are special. The PythonArgParser returns an + # optional>, which cannot be implicitly converted to + # optional>. One needs to unwrap the optional and rewrap. + inits.extend( + [ + f"auto __{name} = {arg_parser_expr};", + f"::std::optional {name} = __{name} ? ::std::make_optional(DimnameList(__{name}.value())) : ::std::nullopt;", # noqa: B950 + ] + ) + lambda_args_exprs[name] = name + else: + # default case - directly using PythonArgParser output expr + lambda_args_exprs[name] = arg_parser_expr + + # method's self is passed directly to python binding, rather than parsed + if ps.method: + lambda_args_exprs["self"] = "self" + + # 2. special packing/checking for TensorOptions. + tensor_options_args_names = [a.name for a in ps.tensor_options_args] + if has_toptions: + if f.func.is_out_fn(): + raise RuntimeError(f"{f.func}: tensor options with output arg") + for a in ps.tensor_options_args: + if a.name not in TENSOR_OPTIONS_FIELDS: + raise RuntimeError( + f"{f.func}: unrecognized tensor options field '{a.name}' in python binding arguments" + ) + if str(a.type) != TENSOR_OPTIONS_FIELDS.get(a.name): + raise RuntimeError( + f"{f.func}: unrecognized type '{str(a.type)}' for tensor options field '{a.name}'" + ) + if not all(a in tensor_options_args_names for a in TENSOR_OPTIONS_FIELDS): + raise RuntimeError( + f"{f.func}: incomplete tensor options args: {tensor_options_args_names}" + ) + + inits.append( + f"""\ +const auto options = TensorOptions() + .dtype({arg_parser_outputs["dtype"].expr}) + .device({arg_parser_outputs["device"].expr}) + .layout({arg_parser_outputs["layout"].expr}) + .requires_grad({arg_parser_outputs["requires_grad"].expr}) + .pinned_memory({arg_parser_outputs["pin_memory"].expr}); +torch::utils::maybe_initialize_device(options); +""" + ) + lambda_args_exprs["options"] = "options" + + # 3. special case - access scattered TensorOptions fields without packing + # TODO: maybe move to the generator side as it's not related to binding. + if not has_toptions and tensor_options_args_names: + if "dtype" in tensor_options_args_names: + # we're an output-arg variant, check these args against output tensor + if not f.func.is_out_fn(): + raise RuntimeError( + f"{f.func}: dtype in tensor_options_args without output arg, {ps} {ps.arguments}" + ) + if not all(a in tensor_options_args_names for a in ("layout", "device")): + raise RuntimeError( + f"{f.func}: incomplete tensor options for output check" + ) + + inits.append( + f"""\ +check_out_type_matches({arg_parser_outputs["out"].expr}, {arg_parser_outputs["dtype"].expr}, + {arg_parser_outputs["dtype"].is_none_expr}, {arg_parser_outputs["layout"].expr}, + {arg_parser_outputs["device"].expr}, {arg_parser_outputs["device"].is_none_expr}); +""" + ) + # we'll set requires_grad on outgoing tensor + if "requires_grad" not in tensor_options_args_names: + raise RuntimeError( + f'{f.func}: expected "requires_grad" in tensor_options_args absent, but found [{tensor_options_args_names}]' + ) + + return DispatchLambdaArgumentExprs( + exprs=tuple(lambda_args_exprs[a.name] for a in lambda_args), + inits=inits, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/types/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/types/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4e98bb8df493f2375b514e6c6aeb897cebe8ec7d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/types/__init__.py @@ -0,0 +1,5 @@ +from torchgen.api.types.types import * +from torchgen.api.types.types_base import * + + +from torchgen.api.types.signatures import * # usort: skip diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/types/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/types/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b289118451028a9cff74da0ac57c469678309429 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/types/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/types/__pycache__/signatures.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/types/__pycache__/signatures.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c049dc3bdecc07f87a506dc706b7ecec9d5a9da Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/types/__pycache__/signatures.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/types/__pycache__/types.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/types/__pycache__/types.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aeba8a83e4113845b4a83b7212a81bd9eaf84cc5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/types/__pycache__/types.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/types/__pycache__/types_base.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/types/__pycache__/types_base.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e257670bfc696b30bb29e3699103ba5b83ddddfe Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/types/__pycache__/types_base.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/types/signatures.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/types/signatures.py new file mode 100644 index 0000000000000000000000000000000000000000..d4a47536dd1ff213bc8bd8aceee2bd22531088a6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/types/signatures.py @@ -0,0 +1,356 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from torchgen.api.types.types_base import Binding, CType, Expr + + +if TYPE_CHECKING: + from collections.abc import Iterator, Sequence + + from torchgen.model import ( + BackendIndex, + FunctionSchema, + NativeFunction, + NativeFunctionsGroup, + NativeFunctionsViewGroup, + ) + + +@dataclass(frozen=True) +class CppSignature: + """ + A CppSignature represents a single overload in the C++ API. For + any given function schema, there may be multiple CppSignatures + corresponding to it, based on how we desugar to C++. See also + CppSignatureGroup. + """ + + # The schema this signature is derived from + func: FunctionSchema + + # Is this a C++ signature for a method, i.e. Tensor::my_op(...)? + method: bool + + # Is this a faithful C++ signature (i.e. following the JIT schema) or a convenience API + # (i.e. with a potential TensorOptions argument and out arguments in the front) + faithful: bool + + # Is this a symint C++ signature. For BC reasons, functions that take + # SymInts still present as int64_t in C++, and the SymInt variant is + # offered at a different overload name + # + # NB: If a function RETURNS a SymInt, this is ALWAYS false + symint: bool + + # The set of C++ arguments which should not have defaults applied to them + cpp_no_default_args: set[str] + + # Is this a fallback C++ binding? Fallback bindings are enabled by + # manual_cpp_binding: True and are alternate, non-public API that + # lets manual C++ binding implementers access the binding that would + # have been automatically generated + fallback_binding: bool = False + + # Return the unpacked argument structure of this signature, + # discarding information about which arguments are semantically + # related to each other. + def arguments(self) -> Sequence[Binding]: + return cpp.arguments( + self.func.arguments, + faithful=self.faithful, + symint=self.symint, + method=self.method, + cpp_no_default_args=self.cpp_no_default_args, + ) + + def name(self, *, suppress_symint_suffix: bool = False) -> str: + n = cpp.name( + self.func, + faithful_name_for_out_overloads=self.faithful, + symint_overload=False if suppress_symint_suffix else self.symint, + ) + if self.fallback_binding: + n = f"__dispatch_{n}" + return n + + # Render the C++ declaration for this signature + def decl( + self, + *, + name: str | None = None, + prefix: str = "", + is_redispatching_fn: bool = False, + suppress_symint_suffix: bool = False, + ) -> str: + returns_type = cpp.returns_type( + self.func.returns, symint=self.symint + ).cpp_type() + cpp_args = [a.decl() for a in self.arguments()] + if is_redispatching_fn: + cpp_args = ["c10::DispatchKeySet dispatchKeySet"] + cpp_args + cpp_args_str = ", ".join(cpp_args) + if name is None: + name = prefix + self.name(suppress_symint_suffix=suppress_symint_suffix) + return f"{returns_type} {name}({cpp_args_str})" + + # Render the C++ definition for this signature, not including + # the body (with curly braces) + def defn( + self, + *, + name: str | None = None, + prefix: str = "", + is_redispatching_fn: bool = False, + ) -> str: + returns_type = cpp.returns_type( + self.func.returns, symint=self.symint + ).cpp_type() + cpp_args = [a.defn() for a in self.arguments()] + if is_redispatching_fn: + cpp_args = ["c10::DispatchKeySet dispatchKeySet"] + cpp_args + cpp_args_str = ", ".join(cpp_args) + if name is None: + name = prefix + self.name() + return f"{returns_type} {name}({cpp_args_str})" + + def ptr_type(self) -> str: + args_types_str = ", ".join(a.type for a in self.arguments()) + return f"{cpp.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_types_str})" + + # Return the C++ function type, e.g., something like int(bool) + def type(self) -> str: + args_types_str = ", ".join(a.type for a in self.arguments()) + return f"{cpp.returns_type(self.func.returns, symint=self.symint).cpp_type()} ({args_types_str})" + + +# Represents group of all CppSignatures associated with a +# FunctionSchema. Right now, that's the regular, user-visible +# signature, as well as a "faithful" signature which doesn't +# have grouping. +@dataclass(frozen=True) +class CppSignatureGroup: + func: FunctionSchema + signature: CppSignature + faithful_signature: CppSignature | None + symint_signature: CppSignature | None + symint_faithful_signature: CppSignature | None + + def most_faithful_signature(self) -> CppSignature: + if self.faithful_signature: + return self.faithful_signature + else: + return self.signature + + def signatures(self, *, symint: bool = True) -> Iterator[CppSignature]: + yield self.signature + if self.faithful_signature: + yield self.faithful_signature + if symint: + if self.symint_signature: + yield self.symint_signature + if self.symint_faithful_signature: + yield self.symint_faithful_signature + + @staticmethod + def from_native_function( + f: NativeFunction, *, method: bool, fallback_binding: bool = False + ) -> CppSignatureGroup: + func = f.func + + def make_sig(*, faithful: bool, symint: bool) -> CppSignature: + return CppSignature( + func=func, + faithful=faithful, + symint=symint, + method=method, + fallback_binding=fallback_binding, + cpp_no_default_args=f.cpp_no_default_args, + ) + + def make_sigs(*, symint: bool) -> tuple[CppSignature, CppSignature | None]: + faithful_signature: CppSignature | None = None + if func.arguments.tensor_options is not None or len(func.arguments.out) > 0: + faithful_signature = make_sig(faithful=True, symint=symint) + signature = make_sig(faithful=False, symint=symint) + return signature, faithful_signature + + signature, faithful_signature = make_sigs(symint=False) + symint_signature: CppSignature | None = None + symint_faithful_signature: CppSignature | None = None + if func.has_symint(): + symint_signature, symint_faithful_signature = make_sigs(symint=True) + + return CppSignatureGroup( + func=func, + signature=signature, + faithful_signature=faithful_signature, + symint_signature=symint_signature, + symint_faithful_signature=symint_faithful_signature, + ) + + +@dataclass(frozen=True) +class DispatcherSignature: + # The schema this signature is derived from + func: FunctionSchema + + # Allows you to prepend an arbitrary prefix to the signature name. + # This is useful for parts of the codegen that generate wrappers around kernels, + # and need to avoid naming collisions. + prefix: str = "" + + symint: bool = True + + def arguments(self) -> list[Binding]: + return dispatcher.arguments(self.func, symint=self.symint) + + def name(self) -> str: + return self.prefix + dispatcher.name(self.func) + + def decl(self, name: str | None = None) -> str: + args_str = ", ".join(a.decl() for a in self.arguments()) + if name is None: + name = self.name() + return f"{self.returns_type().cpp_type()} {name}({args_str})" + + def defn( + self, name: str | None = None, *, is_redispatching_fn: bool = False + ) -> str: + args = [a.defn() for a in self.arguments()] + if is_redispatching_fn: + args = ["c10::DispatchKeySet dispatchKeySet"] + args + args_str = ", ".join(args) + if name is None: + name = self.name() + return f"{self.returns_type().cpp_type()} {name}({args_str})" + + def exprs(self) -> list[Expr]: + return [Expr(a.name, a.nctype) for a in self.arguments()] + + def returns_type(self) -> CType: + return dispatcher.returns_type(self.func.returns, symint=self.symint) + + def ptr_type(self) -> str: + dispatcher_args_types_str = ", ".join(a.type for a in self.arguments()) + return f"{self.returns_type().cpp_type()} (*)({dispatcher_args_types_str})" + + # Return the C++ function type, e.g., something like int(bool) + def type(self) -> str: + dispatcher_args_types_str = ", ".join(a.type for a in self.arguments()) + return f"{self.returns_type().cpp_type()} ({dispatcher_args_types_str})" + + @staticmethod + def from_schema( + func: FunctionSchema, *, prefix: str = "", symint: bool = True + ) -> DispatcherSignature: + return DispatcherSignature(func, prefix, symint) + + +@dataclass(frozen=True) +class NativeSignature: + # The schema this signature is derived from + func: FunctionSchema + + symint: bool + + prefix: str = "" + + def name(self) -> str: + return self.prefix + native.name(self.func) + + def decl(self, name: str | None = None) -> str: + args_str = ", ".join(a.decl() for a in self.arguments()) + if name is None: + name = self.name() + return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})" + + def defn(self, name: str | None = None) -> str: + args_str = ", ".join(a.defn() for a in self.arguments()) + if name is None: + name = self.name() + return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})" + + def ptr_type(self) -> str: + # don't include defaults in type signature! + args_str = ", ".join(a.defn() for a in self.arguments()) + return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_str})" + + def arguments(self) -> list[Binding]: + return native.arguments(self.func, symint=self.symint) + + def returns_type(self) -> CType: + return native.returns_type(self.func.returns, symint=self.symint) + + def dispatcher_exprs(self) -> list[Expr]: + return translate.translate( + self.arguments(), dispatcher.arguments(self.func), method=False + ) + + +@dataclass(frozen=True) +class ViewInverseSignature: + g: NativeFunctionsViewGroup + + def name(self) -> str: + return functionalization.reverse_name(self.g.view, include_namespace=False) + + def decl(self) -> str: + return_type = functionalization.returns_type(self.g.view.func) + decls = [ + a.decl() + for a in functionalization.op_arguments(self.g.view.func, is_reverse=True) + ] + return f"static {return_type.cpp_type()} {self.name()}({', '.join(decls)});" + + +@dataclass(frozen=True) +class StructuredImplSignature: + g: NativeFunctionsGroup + name: str + + def defn(self, name: str | None = None) -> str: + args_str = ", ".join(a.defn() for a in self.arguments()) + return f"TORCH_IMPL_FUNC({self.name})({args_str})" + + def arguments(self) -> list[Binding]: + return structured.impl_arguments(self.g) + + +# Helper functions + + +def kernel_signature( + f: NativeFunction, backend_index: BackendIndex, *, prefix: str = "" +) -> NativeSignature | DispatcherSignature: + # Note [External Backends Follow Dispatcher API] + # Kernel signatures for in-tree backends follow the "native" API, + # while kernels for out-of-tree backends follow the dispatcher API. + # See the comments in `native.py` for details, but historically there have been + # some small differences in schema convention between them and the Dispatcher API. + # Any differences that require translating between the two will results in a runtime cost, + # so we'd like to keep the differences as small as possible. + # With external backends, we'd like to enforce that they write their kernels with schemas + # that match the Dispatcher API directly, if they can. + meta = backend_index.get_kernel(f) + symint = meta is not None and meta.supports_symint() + if symint: + assert f.func.has_symint(), ( + f"attempted to define symint kernel for {backend_index.dispatch_key} without SymInt in schema" + ) + if backend_index.external: + return DispatcherSignature.from_schema(f.func, prefix=prefix, symint=symint) + else: + return NativeSignature(f.func, prefix=prefix, symint=symint) + + +# Functions only, no types +from torchgen.api import ( + cpp, + dispatcher, + functionalization, + native, + structured, + translate, +) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/types/types.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/types/types.py new file mode 100644 index 0000000000000000000000000000000000000000..41c05653fffdf3d04fc7078e7df142124ed96e00 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/types/types.py @@ -0,0 +1,183 @@ +""" +Where should I add a new type? `types_base.py` vs `types.py` + +This file defines data model classes for torchgen typing system, as well as some base types such as int32_t. + +`types.py` defines ATen Tensor type and some c10 types, along with signatures that use these types. + +The difference between these two files, is `types_base.py` should be implementation-agnostic, meaning it shouldn't +contain any type definition that is tight to a specific C++ library (e.g., ATen), so that it can be easily reused +if we want to generate code for another C++ library. + +Add new types to `types.py` if these types are ATen/c10 related. +Add new types to `types_base.py` if they are basic and not attached to ATen/c10. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +from torchgen.api.types.types_base import ( + BaseCppType, + BaseCType, + boolT, + byteT, + charT, + CType, + doubleT, + floatT, + int32T, + longT, + shortT, +) +from torchgen.model import BaseTy, ScalarType + + +TENSOR_LIST_LIKE_CTYPES = [ + "at::TensorList", + "const c10::List<::std::optional> &", + "const at::ITensorListRef &", +] + + +halfT = BaseCppType("at", "Half") +complexHalfT = BaseCppType( + "c10", "complex" +) # stuffing template param here is an abuse +complexFloatT = BaseCppType("c10", "complex") +complexDoubleT = BaseCppType("c10", "complex") +bfloat16T = BaseCppType("at", "BFloat16") +float8_e5m2T = BaseCppType("at", "Float8_e5m2") +float8_e5m2fnuzT = BaseCppType("at", "Float8_e5m2fnuz") +float8_e4m3fnT = BaseCppType("at", "Float8_e4m3fn") +float8_e4m3fnuzT = BaseCppType("at", "Float8_e4m3fnuz") +float8_e8m0fnuT = BaseCppType("at", "Float8_e8m0fnu") +stringT = BaseCppType("c10", "string_view") +generatorT = BaseCppType("at", "Generator") +scalarTypeT = BaseCppType("at", "ScalarType") +tensorT = BaseCppType("at", "Tensor") +optionalTensorRefT = BaseCppType("at", "OptionalTensorRef") +tensorListT = BaseCppType("at", "TensorList") +iTensorListRefT = BaseCppType("at", "ITensorListRef") +iOptTensorListRefT = BaseCppType("at", "IOptTensorListRef") +dimnameT = BaseCppType("at", "Dimname") +dimnameListT = BaseCppType("at", "DimnameList") +dimVectorT = BaseCppType("at", "DimVector") +layoutT = BaseCppType("at", "Layout") +deviceT = BaseCppType("at", "Device") +deviceIndexT = BaseCppType("at", "DeviceIndex") +scalarT = BaseCppType("at", "Scalar") +optionalScalarRefT = BaseCppType("at", "OptionalScalarRef") +memoryFormatT = BaseCppType("at", "MemoryFormat") +qschemeT = BaseCppType("at", "QScheme") +storageT = BaseCppType("at", "Storage") +streamT = BaseCppType("at", "Stream") +intArrayRefT = BaseCppType("at", "IntArrayRef") +optionalIntArrayRefT = BaseCppType("at", "OptionalIntArrayRef") +optionalSymIntArrayRefT = BaseCppType("at", "OptionalSymIntArrayRef") +tensorOptionsT = BaseCppType("at", "TensorOptions") +typeAndSizeT = BaseCppType("torch::autograd::generated", "TypeAndSize") +tensorGeometryT = BaseCppType("at", "TensorGeometry") +SymIntT = BaseCppType("c10", "SymInt") +SymBoolT = BaseCppType("c10", "SymBool") +symIntArrayRefT = BaseCppType("c10", "SymIntArrayRef") + +# Types representing template parameters. Technically, we probably shouldn't +# represent them this way in codegen, but it was pretty convenient. +scalar_t = BaseCppType("", "scalar_t") +opmath_t = BaseCppType("", "opmath_t") + +ScalarTypeToCppMapping: dict[ScalarType, BaseCppType] = { + ScalarType.Byte: byteT, + ScalarType.Char: charT, + ScalarType.Short: shortT, + ScalarType.Int: int32T, + ScalarType.Long: longT, + ScalarType.Half: halfT, + ScalarType.Float: floatT, + ScalarType.Double: doubleT, + ScalarType.ComplexHalf: complexHalfT, + ScalarType.ComplexFloat: complexFloatT, + ScalarType.ComplexDouble: complexDoubleT, + ScalarType.Bool: boolT, + ScalarType.Float8_e5m2: float8_e5m2T, + ScalarType.Float8_e5m2fnuz: float8_e5m2fnuzT, + ScalarType.Float8_e4m3fn: float8_e4m3fnT, + ScalarType.Float8_e4m3fnuz: float8_e4m3fnuzT, + ScalarType.Float8_e8m0fnu: float8_e8m0fnuT, +} + +BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = { + BaseTy.int: longT, + BaseTy.float: doubleT, + BaseTy.bool: boolT, + BaseTy.str: stringT, + BaseTy.Generator: generatorT, + BaseTy.ScalarType: scalarTypeT, + BaseTy.Tensor: tensorT, + BaseTy.Dimname: dimnameT, + BaseTy.DimVector: dimVectorT, + BaseTy.Layout: layoutT, + BaseTy.Device: deviceT, + BaseTy.DeviceIndex: deviceIndexT, + BaseTy.Scalar: scalarT, + BaseTy.MemoryFormat: memoryFormatT, + BaseTy.QScheme: qschemeT, + BaseTy.Storage: storageT, + BaseTy.Stream: streamT, + BaseTy.SymInt: SymIntT, + BaseTy.SymBool: SymBoolT, +} + +# CTypes encode C++ type structure as needed for translation. + + +@dataclass(frozen=True) +class OptionalCType(CType): + elem: CType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + # Do not pass `strip_ref` recursively. + return f"::std::optional<{self.elem.cpp_type()}>" + + def remove_const_ref(self) -> CType: + return OptionalCType(self.elem.remove_const_ref()) + + +@dataclass(frozen=True) +class ListCType(CType): + elem: CType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + # Do not pass `strip_ref` recursively. + return f"c10::List<{self.elem.cpp_type()}>" + + def remove_const_ref(self) -> CType: + return ListCType(self.elem.remove_const_ref()) + + +@dataclass(frozen=True) +class ArrayRefCType(CType): + elem: CType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + # Do not pass `strip_ref` recursively. + return f"at::ArrayRef<{self.elem.cpp_type()}>" + + def remove_const_ref(self) -> CType: + return ArrayRefCType(self.elem.remove_const_ref()) + + +@dataclass(frozen=True) +class VectorizedCType(CType): + # This template is explicitly specialized, so the only valid + # elems are those we have specializations for (e.g., float, double, ...) + # scalar_t is also a common argument here (when we are codegen in + # a templated context) + elem: BaseCType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + return f"at::vec::Vectorized<{self.elem.cpp_type()}>" + + def remove_const_ref(self) -> CType: + return self diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/types/types_base.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/types/types_base.py new file mode 100644 index 0000000000000000000000000000000000000000..08085fa0fa2bf04b3be6d9a9b8c411c9bbfed6d8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/api/types/types_base.py @@ -0,0 +1,238 @@ +""" +Where should I add a new type? `types_base.py` vs `types.py` + +This file defines data model classes for torchgen typing system, as well as some base types such as int32_t. + +`types.py` defines ATen Tensor type and some c10 types, along with signatures that use these types. + +The difference between these two files, is `types_base.py` should be implementation-agnostic, meaning it shouldn't +contain any type definition that is tight to a specific C++ library (e.g., ATen), so that it can be easily reused +if we want to generate code for another C++ library. + +Add new types to `types.py` if these types are ATen/c10 related. +Add new types to `types_base.py` if they are basic and not attached to ATen/c10. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import auto, Enum +from typing import TYPE_CHECKING, Union + + +if TYPE_CHECKING: + from torchgen.model import Argument, SelfArgument, TensorOptionsArguments + + +# An ArgName is just the str name of the argument in schema; +# but in some special circumstances, we may add a little extra +# context. The Enum SpecialArgName covers all of these cases; +# grep for their construction sites to see when they can occur. + + +class SpecialArgName(Enum): + possibly_redundant_memory_format = auto() + + +ArgName = Union[str, SpecialArgName] + + +# This class shouldn't be created directly; instead, use/create one of the singletons below. +@dataclass(frozen=True) +class BaseCppType: + ns: str | None + name: str + + def __str__(self) -> str: + if self.ns is None or self.ns == "": + return self.name + return f"{self.ns}::{self.name}" + + +# The set of all non-templated, valid, fully-qualified names of C++ types that are used in the codegen. +# Templated types get their own dataclass, mainly to make namespace parsing easier. +byteT = BaseCppType("", "uint8_t") +charT = BaseCppType("", "int8_t") +shortT = BaseCppType("", "int16_t") +# It would be more symmetric for this to be called intT, but it easy to mix +# this up with JIT int (which is int64_t in C++), so we intentionally don't +# define intT to make it obvious when you've stuffed it up +int32T = BaseCppType("", "int32_t") +longT = BaseCppType("", "int64_t") +doubleT = BaseCppType("", "double") +floatT = BaseCppType("", "float") +boolT = BaseCppType("", "bool") +voidT = BaseCppType("", "void") + + +class CType(ABC): + @abstractmethod + def cpp_type(self, *, strip_ref: bool = False) -> str: + raise NotImplementedError + + @abstractmethod + def remove_const_ref(self) -> CType: + return self + + +@dataclass(frozen=True) +class BaseCType(CType): + type: BaseCppType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + return str(self.type) + + def remove_const_ref(self) -> CType: + return self + + +@dataclass(frozen=True) +class ConstRefCType(CType): + elem: CType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + if strip_ref: + return self.elem.cpp_type(strip_ref=strip_ref) + return f"const {self.elem.cpp_type()} &" + + def remove_const_ref(self) -> CType: + return self.elem.remove_const_ref() + + +@dataclass(frozen=True) +class VectorCType(CType): + elem: CType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + # Do not pass `strip_ref` recursively. + return f"::std::vector<{self.elem.cpp_type()}>" + + def remove_const_ref(self) -> CType: + return VectorCType(self.elem.remove_const_ref()) + + +@dataclass(frozen=True) +class ArrayCType(CType): + elem: CType + size: int + + def cpp_type(self, *, strip_ref: bool = False) -> str: + # Do not pass `strip_ref` recursively. + return f"::std::array<{self.elem.cpp_type()},{self.size}>" + + def remove_const_ref(self) -> CType: + return ArrayCType(self.elem.remove_const_ref(), self.size) + + +@dataclass(frozen=True) +class TupleCType(CType): + elems: list[CType] + + def cpp_type(self, *, strip_ref: bool = False) -> str: + # Do not pass `strip_ref` recursively. + return f"::std::tuple<{','.join([e.cpp_type() for e in self.elems])}>" + + def remove_const_ref(self) -> CType: + return TupleCType([e.remove_const_ref() for e in self.elems]) + + +@dataclass(frozen=True) +class MutRefCType(CType): + elem: CType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + if strip_ref: + return self.elem.cpp_type(strip_ref=strip_ref) + return f"{self.elem.cpp_type()} &" + + def remove_const_ref(self) -> CType: + return self.elem.remove_const_ref() + + +# A NamedCType is short for Named C++ semantic type. A NamedCType represents a C++ type, plus +# semantic information about what it represents. For example, consider the +# argument "bool pin_memory"; its normal C++ type is "bool", but its C++ +# semantic type also keeps track that this represents a "pin_memory"; you can't +# just use a random other boolean in a context where you need a "pin_memory"! +# + + +@dataclass(frozen=True) +class NamedCType: + name: ArgName + type: CType + + def cpp_type(self, *, strip_ref: bool = False) -> str: + return self.type.cpp_type(strip_ref=strip_ref) + + def remove_const_ref(self) -> NamedCType: + return NamedCType(self.name, self.type.remove_const_ref()) + + def with_name(self, name: str) -> NamedCType: + return NamedCType(name, self.type) + + +# A binding represents any C++ binding site for a formal parameter. +# We don't distinguish between binding sites for different APIs; +# instead, all of the important distinctions are encoded in CType, +# which you can use to figure out if a given Binding is appropriate +# for use in another context. (See torchgen.api.translate) + + +@dataclass(frozen=True) +class Binding: + name: str + nctype: NamedCType + argument: Argument | TensorOptionsArguments | SelfArgument + # TODO: maybe don't represent default here + default: str | None = None + + def rename(self, name: str) -> Binding: + return Binding( + name=name, + nctype=self.nctype, + argument=self.argument, + default=self.default, + ) + + @property + def type(self) -> str: + return self.nctype.cpp_type() + + def no_default(self) -> Binding: + return Binding( + name=self.name, + nctype=self.nctype, + default=None, + argument=self.argument, + ) + + def decl(self, *, func_ptr_cast: bool = False) -> str: + mb_default = "" + if self.default is not None: + mb_default = f"={self.default}" + + # casting only needs to know the type + if func_ptr_cast: + return f"{self.type}" + else: + return f"{self.type} {self.name}{mb_default}" + + def defn(self) -> str: + return f"{self.type} {self.name}" + + def with_name(self, name: str) -> Binding: + return Binding( + name=name, nctype=self.nctype, argument=self.argument, default=self.default + ) + + +# An Expr is a C++ expression. It has a C++ string representing its syntax, +# as well as a CType saying what it provides. + + +@dataclass(frozen=True) +class Expr: + expr: str + type: NamedCType diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f83363c3f1cd43e7834200d3838bf7137c4aeb02 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/__pycache__/lazy_ir.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/__pycache__/lazy_ir.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95ac79c651a96a973bb853fa36bf4e734893f224 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/__pycache__/lazy_ir.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/__pycache__/lazy_ts_lowering.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/__pycache__/lazy_ts_lowering.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4847a345898befd644756d1c46fe4fec3d8c21f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/__pycache__/lazy_ts_lowering.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/__pycache__/native_functions.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/__pycache__/native_functions.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6697542214f5596e0bc0db7a3ac811bb6fcee1f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/__pycache__/native_functions.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/__pycache__/register_dispatch_key.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/__pycache__/register_dispatch_key.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6abe08c91f25eca7e376d7253ab1e71a65c26c50 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/__pycache__/register_dispatch_key.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/__pycache__/ufunc.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/__pycache__/ufunc.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ff64e0453ceb33b9e8f76d675e41ffbb64dd08d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/dest/__pycache__/ufunc.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/operator_versions/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/operator_versions/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63e8c0288b5111004d655c84a03cf0fd33a307f1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/operator_versions/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db9208956d7668ee727e8c072e736175598e40c1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders_constant.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders_constant.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..295cfa64899a5076107274f56dd7640c095f355d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders_constant.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..35f43297fdd9ca9f932c8c53b5b773f1b9b8a427 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions.h @@ -0,0 +1,29 @@ +#include + +// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch] +// Code introduced to avoid cyclic dependency in static dispatch is no longer +// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place, +// to Operators.cpp for supporting multiple backends with multiple kernels. +// +// Note [Avoiding Include Cycles In Static Dispatch] +// In order to avoid #include cycles in the static dispatch build, we've carefully split out +// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h. +// +// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h. +// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods +// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all +// directly inlined into TensorBody.h. +// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API, +// which include functions that have defaultable std::optional arguments. +// That requires knowing the full Tensor class definition. +// +// We break the cycle by doing the following: +// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h +// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl., +// - CPUFunctions_inl.h includes everything else +// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class, +// and then it includes CPUFunctions_inl.h. +// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly. +// - This also means that static dispatch build, CPUFunctions.h only needs to +// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h. +${inline_headers} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions_inl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..fbb71c2cb123cb21fb57ec32341d86bff06f6a17 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions_inl.h @@ -0,0 +1,22 @@ +#pragma once +// ${generated_comment} + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from \ + . \ + See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +${DispatchKeyFunctions_inl_includes} + + +${dispatch_namespaced_declarations} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.cpp b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7647f459a744b2eacfac6aaea4f49b86babbb234 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.cpp @@ -0,0 +1,13 @@ +// ${generated_comment} +${includes} +${native_functions_include} + +namespace { +${helper_fns} +} // namespace + +${namespace_prologue} + +${native_function_definitions} + +${namespace_epilogue} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/Function.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/Function.h new file mode 100644 index 0000000000000000000000000000000000000000..73096afbf11571cbe4147bb63f035a054ca842db --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/Function.h @@ -0,0 +1,27 @@ +#pragma once + +// ${generated_comment} + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +${static_dispatch_ops_headers} + +${operator_includes} + +namespace at { + +${function_definitions} + +} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/FunctionalInverses.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/FunctionalInverses.h new file mode 100644 index 0000000000000000000000000000000000000000..b15cd09a6c65da3127be8245b87bff2f8c795a3d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/FunctionalInverses.h @@ -0,0 +1,23 @@ +#pragma once + +// ${generated_comment} + +#include +#include + +namespace at { +namespace functionalization { + +struct FunctionalInverses { + +${view_inverse_declarations} + +// NB: These are not generated! They're manually implemented in the template. +// TODO: Change codegen to generate these. See the following link: +// https://github.com/pytorch/pytorch/blob/main/torchgen/model.py#L2583-L2585 +static at::Tensor chunk_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, int chunks, int dim); +static at::Tensor narrow_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, int dim, c10::SymInt start, c10::SymInt length); + +}; +} +} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/LazyIr.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/LazyIr.h new file mode 100644 index 0000000000000000000000000000000000000000..9190ff8243d316fd2bd472bb3f0603701761bdb7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/LazyIr.h @@ -0,0 +1,19 @@ +#pragma once + +// This file contains autogenerated LazyTensor IR nodes +${lazy_ir_sysinc} +${lazy_ir_inc} + +${namespace_prologue} +using at::operator<<; + +// kNullValue is used to contribute a static hash value any time +// a node has an Optional input that is nullopt. It is important +// to differentiate between HASH(std::nullopt, something) and HASH(something, std::nullopt), +// and using kNullValue in the hash function in the order of arguments +// serves this purpose. +static const torch::lazy::Value kNullValue = torch::lazy::Value(); + +${ir_declarations} + +${namespace_epilogue} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/LazyNonNativeIr.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/LazyNonNativeIr.h new file mode 100644 index 0000000000000000000000000000000000000000..18eaf6da52e4b3654becac6cc89849bc0806ae09 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/LazyNonNativeIr.h @@ -0,0 +1,11 @@ +#pragma once + +${lazy_non_native_ir_inc} + +// This file contains autogenerated LazyTensor Non Native IR nodes + +${namespace_prologue} + +${non_native_ir_nodes} + +${namespace_epilogue} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/NativeFunction.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/NativeFunction.h new file mode 100644 index 0000000000000000000000000000000000000000..a5441ad85d1d5e28c4e31dd3f0dc7f66dfbff9e7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/NativeFunction.h @@ -0,0 +1,17 @@ +#pragma once + +// ${generated_comment} + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +${extra_includes} + +${native_function_declarations} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/NativeFunctions.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/NativeFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..9dc972495ca038bddb7b887c39c2e0507e487213 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/NativeFunctions.h @@ -0,0 +1,33 @@ +#pragma once + +// ${generated_comment} + +#ifdef TORCH_ASSERT_NO_OPERATORS +#error This change adds a dependency on native_functions.yaml, \ + meaning the file will need to be re-compiled every time an operator \ + is changed or added. Consider if your change would be better placed in \ + another file, or if a more specific header might achieve the same goal. \ + See NOTE: [Tensor vs. TensorBase] +#endif + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from \ + and see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +${NativeFunctions_includes} + +${NativeFunctions_declarations} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/NativeMetaFunction.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/NativeMetaFunction.h new file mode 100644 index 0000000000000000000000000000000000000000..6522c97546d0498e4b3825fb4eafefbb34c71911 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/NativeMetaFunction.h @@ -0,0 +1,23 @@ +#pragma once + +// ${generated_comment} + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace meta { + +${meta_function_declarations} + +} // namespace native +} // namespace at diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/Operator.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/Operator.h new file mode 100644 index 0000000000000000000000000000000000000000..ed220f917290c2062481eb53dca232b47d180e2d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/Operator.h @@ -0,0 +1,19 @@ +#pragma once + +// ${generated_comment} + +#include +#include +#include + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +namespace at { +namespace _ops { + +${declarations} + +}} // namespace at::_ops diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/Operators.cpp b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/Operators.cpp new file mode 100644 index 0000000000000000000000000000000000000000..082bb67c3e2043f2c36b29345f57048ec2e9eea7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/Operators.cpp @@ -0,0 +1,19 @@ +#include +#include + +// ${generated_comment} +// NOTE See [Sharded File] comment in VariableType + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +${operator_headers} +#endif + +${static_dispatch_extra_headers} + +namespace at { namespace _ops { + +${definitions} + +}} // namespace at::_ops diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/Operators.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/Operators.h new file mode 100644 index 0000000000000000000000000000000000000000..e74b96ef3d5c6b6d50fe63eac4dca51f0655daa5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/Operators.h @@ -0,0 +1,74 @@ +#pragma once + +// ${generated_comment} + +#ifdef TORCH_ASSERT_NO_OPERATORS +#error This change adds a dependency on native_functions.yaml, \ + meaning the file will need to be re-compiled every time an operator \ + is changed or added. Consider if your change would be better placed in \ + another file, or if a more specific header might achieve the same goal. \ + See NOTE: [Tensor vs. TensorBase] +#endif + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from \ + and see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +${Operators_includes} + +// Extension writers: do you write wrapper functions? Are you frustrated with +// resolving overloads of operators? Are you frustrated with dealing with +// pointer-to-methods and resolving overloads of pointer-to-methods?? Look no +// further, this is the utility for you. +// +// Given an operator schema: aten::op.overload(... +// +// Use ATEN_FN2(op, overload) to get a *function* version of the operator +// that is guaranteed to not be overloaded. This means that you can safely +// decltype(&ATEN_FN2(op, overload)) it. NB: the 2 means this macro takes 2 args. +// +// Given an operator schema without an overload name: aten::op(... +// +// Use ATEN_FN(op) to get an unambiguous *function* version of the operator. +// +// There is some interesting behavior for out= operations. +// ATEN_FN2(sin, out) gives a function that is *faithful* to the schema; +// that is, the order of arguments is exactly what it looks like in the schema. + +#define ATEN_FN2(op_name, overload) at::_ops::op_name##_##overload::call +#define ATEN_FN(op_name) at::_ops::op_name::call + +// Separately, ATEN_OP(op) and ATEN_OP2(op, overload) define a class containing compile-time +// metadata about a given aten operator. +// Notable data on the class includes: +// - ATEN_OP2(add, Tensor)::name // returns the string name: "add" +// - ATEN_OP2(add, Tensor)::overload_name // returns the string overload name: "Tensor" +// - ATEN_OP2(add, Tensor)::schema // returns the C++ schema type: at::Tensor (const at::Tensor &, const at::Tensor &, const at::Scalar &) +// - ATEN_OP2(add, Tensor)::schema_str // returns the string jit type: "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor" + +#define ATEN_OP2(op_name, overload) at::_ops::op_name##_##overload +#define ATEN_OP(op_name) at::_ops::op_name + +// WARNING: Please do not call any of the ops in the _ops namespace directly. +// Use the ATEN_FN macros. We do not guarantee stability of the naming +// scheme for the functions in at::_ops + +// See Note [The ATen Operators API] for details of the at::_ops namespace + +namespace at { +namespace _ops { +${Operators_declarations} +} // namespace _ops +} // namespace at diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/RegisterBackendSelect.cpp b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/RegisterBackendSelect.cpp new file mode 100644 index 0000000000000000000000000000000000000000..018cf358f11237d5bdc9bca01aa8d09d1462f574 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/RegisterBackendSelect.cpp @@ -0,0 +1,29 @@ +// We register ops with a higher priority dispatch key (BackendSelect) than the usual backend-specific keys (e.g. CPU) +// which makes calls to the factory functions dispatch to here. +// We then 'manually' compute a lower-priority to re-dispatch to (e.g. CPU) to get to the eventually correct backend. +// ${generated_comment} + +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else + +${ops_headers} +#endif + +namespace at { + +namespace { + +${backend_select_method_definitions} + +TORCH_LIBRARY_IMPL(aten, BackendSelect, m) { + ${backend_select_function_registrations}; +} + +} // namespace +} // at diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/RegisterDispatchDefinitions.ini b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/RegisterDispatchDefinitions.ini new file mode 100644 index 0000000000000000000000000000000000000000..97c921de18f62832d1ca09c245f2466541fe908d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/RegisterDispatchDefinitions.ini @@ -0,0 +1,22 @@ +${ns_prologue} + +// NB: TORCH_LIBRARY_IMPL must be in an anonymous namespace to avoid +// ambiguity with conflicting identifiers that may have been defined in +// at namespace already. +namespace { + +${dispatch_anonymous_definitions} + +${static_init_dispatch_registrations} + +} // anonymous namespace + +${deferred_dispatch_registrations} + +namespace ${dispatch_namespace} { + +${dispatch_namespaced_definitions} + +} // namespace ${dispatch_namespace} + +${ns_epilogue} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/RegisterFunctionalization.cpp b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/RegisterFunctionalization.cpp new file mode 100644 index 0000000000000000000000000000000000000000..408aff0cdab40461a7ba731bab216a7b7435331e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/RegisterFunctionalization.cpp @@ -0,0 +1,116 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// ${generated_comment} + +#include +#include +#include +#include +#include +#include + +#include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +// needed for the meta tensor calls to get stride info in functionalization +#include +// needed for special handling of copy_(). +// See Note [functionalizating copy_() and not preserving strides] +#include +#include + +$ops_headers +#endif + +namespace at { +namespace functionalization { + +// This keyset is used by functionalization when it calls into meta kernels +// to accurately propagate stride metadata. +// Exclude any modes: the purpose of calling into meta kernels is only as an implementation +// detail to perform shape inference, and we don't want any modal keys to run. +// Specifically, we want to prevent functionalization and Python modes from running. +constexpr auto exclude_keys_for_meta_dispatch = + c10::functorch_transforms_ks | + c10::DispatchKeySet({ + c10::DispatchKey::FuncTorchDynamicLayerBackMode, + c10::DispatchKey::FuncTorchDynamicLayerFrontMode, + c10::DispatchKey::Python, + c10::DispatchKey::PreDispatch, + + }); + +// Helper around at::has_internal_overlap. +// The ATen util is used in hot-path eager mode: it's always fast, +// but might return TOO_HARD sometimes. +// During functionalization, we're ok taking a bit longer +// to detect memory overlap. +inline bool has_internal_overlap_helper(const at::Tensor t) { + auto has_overlap = at::has_internal_overlap(t); + if (has_overlap == at::MemOverlap::Yes) return true; + if (has_overlap == at::MemOverlap::No) return false; + return false; +} + + +inline Tensor to_meta(const Tensor& t) { + if (!t.defined()) return t; + return at::native::empty_strided_meta_symint(t.sym_sizes(), t.sym_strides(), +/*dtype=*/t.scalar_type(), /*layout=*/t.layout(), +/*device=*/c10::Device(kMeta), /*pin_memory=*/std::nullopt); +} + +inline std::optional to_meta(const std::optional& t) { + if (t.has_value()) { + return to_meta(*t); + } + return std::nullopt; +} + +inline std::vector to_meta(at::ITensorListRef t_list) { + std::vector outputs; + outputs.reserve(t_list.size()); + for (const auto& tensor : t_list) { + outputs.push_back(to_meta(tensor)); + } + return outputs; +} + +inline c10::List to_meta(const c10::List& t_list) { + c10::List outputs; + outputs.reserve(t_list.size()); + for (const auto i : c10::irange(t_list.size())) { + outputs.push_back(to_meta(t_list[i])); + } + return outputs; +} + +inline c10::List<::std::optional> to_meta(const c10::List<::std::optional>& t_list) { + c10::List<::std::optional> outputs; + outputs.reserve(t_list.size()); + for (const auto i : c10::irange(t_list.size())) { + outputs.push_back(to_meta(t_list[i])); + } + return outputs; +} + +static bool disable_meta_reference() { + static auto env = c10::utils::get_env("TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE"); + return env == "1"; +} + + +${func_definitions} + +} // namespace functionalization + +namespace { + +TORCH_LIBRARY_IMPL(aten, Functionalize, m) { + ${func_registrations}; +} + +} // namespace + +} // namespace at diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/UnboxingFunctions.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/UnboxingFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..a65469a9b0123cbfd4075ff3c263276aa47f137f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/UnboxingFunctions.h @@ -0,0 +1,32 @@ +// ${generated_comment} + +// Generated by tools/jit/gen_unboxing.py. This file declares code generated boxed C++ functions for operators, +// base off of native_functions.yaml (or similar yaml file with the same syntax). The definition of such a boxed +// function will pop out IValues from the stack then convert them into the correct C++ types based on given schema. This +// unboxing logic is an alternative to template-based metaprogramming unboxing. + +#pragma once + +#include +namespace at { +namespace unboxing { +namespace { + +template +std::array as_array(const c10::List& list) { + std::array res; + AT_ASSERT(list.size() == N); + std::vector vec; + for (c10::IValue elem : list) { + vec.push_back(elem.to()); + } + std::copy(vec.begin(), vec.end(), res.begin()); + return res; +} +} // namespace +using Stack = std::vector; +// Generated function declaration +${declarations} + +} // namespace unboxing +} // namespace at diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/ViewMetaClasses.cpp b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/ViewMetaClasses.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0fd53171935f9147ba54bcd39a886e2f4dda6b2f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/ViewMetaClasses.cpp @@ -0,0 +1,19 @@ +// ${generated_comment} + +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +${op_headers} +#endif + +namespace at { +namespace functionalization { + +${view_meta_implementations} + +} // namespace functionalization +} // namespace at diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/aten_interned_strings.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/aten_interned_strings.h new file mode 100644 index 0000000000000000000000000000000000000000..326d4622334a776f4f1f94fb49a70f2c53c7e6eb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/aten_interned_strings.h @@ -0,0 +1,22 @@ +#pragma once + +// ${generated_comment} + +#if defined(TORCH_ASSERT_NO_OPERATORS) || defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on native_functions.yaml, \ + meaning the file will need to be re-compiled every time an operator \ + is changed or added. Consider if including for \ + the c10::Symbol class would be sufficient, or if your change would be \ + better placed in another file. +#endif + +// ATen symbols correspond exactly to operators defined in ATen. Every +// symbol here corresponds exactly to an ATen operation defined in +// native_functions.yaml; attributes are in one-to-one correspondence +// with their ATen name. + +#define FORALL_ATEN_BASE_SYMBOLS(_) \ +${aten_symbols} + +#define FORALL_ATTR_BASE_SYMBOLS(_) \ +${attr_symbols} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/enum_tag.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/enum_tag.h new file mode 100644 index 0000000000000000000000000000000000000000..1320fbc28ab8f7d72655816292f49a4c9a9b727d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/ATen/templates/enum_tag.h @@ -0,0 +1,10 @@ +#pragma once + +// ${generated_comment} + +namespace at { + // Enum of valid tags obtained from the entries in tags.yaml + enum class Tag { + ${enum_of_valid_tags} + }; +} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/BUILD.bazel b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/BUILD.bazel new file mode 100644 index 0000000000000000000000000000000000000000..d1a0db360d230fe0f027c19869c6307f17010503 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/BUILD.bazel @@ -0,0 +1,4 @@ +load("//:tools/bazel.bzl", "rules") +load(":build.bzl", "define_targets") + +define_targets(rules = rules) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/README.md b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bfa43899cc590959c2bfd74e38662ec03aaee3d6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/README.md @@ -0,0 +1,3 @@ +If you add a file to this directory, you **MUST** update +`torch/CMakeLists.txt` and add the file as a dependency to +the `add_custom_command` call. diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/build.bzl b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/build.bzl new file mode 100644 index 0000000000000000000000000000000000000000..c5ddf7a20b800a714431fdc9feb57679783410f4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/build.bzl @@ -0,0 +1,20 @@ +def define_targets(rules): + rules.py_library( + name = "autograd", + srcs = rules.glob(["*.py"]), + data = rules.glob([ + "*.yaml", + "templates/*", + ]), + visibility = ["//:__subpackages__"], + deps = [ + rules.requirement("PyYAML"), + "//torchgen", + ], + ) + + rules.filegroup( + name = "deprecated_yaml", + srcs = ["deprecated.yaml"], + visibility = ["//:__subpackages__"], + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/context.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/context.py new file mode 100644 index 0000000000000000000000000000000000000000..0ed4b2ee4d014be3dca01c3f2293b36b03b7880b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/context.py @@ -0,0 +1,31 @@ +import functools +from collections.abc import Callable + +from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo as NFWDI +from torchgen.context import native_function_manager +from torchgen.utils import T + + +# Like tools.api.context.with_native_function, but for +# NativeFunctionWithDifferentiabilityInfo. +def with_native_function_with_differentiability_info( + func: Callable[[NFWDI], T], +) -> Callable[[NFWDI], T]: + @functools.wraps(func) + def wrapper(f: NFWDI) -> T: + with native_function_manager(f.func): + return func(f) + + return wrapper + + +# Like the above but with an additional dispatch key string argument +def with_native_function_with_differentiability_info_and_key( + func: Callable[[NFWDI, str], T], +) -> Callable[[NFWDI, str], T]: + @functools.wraps(func) + def wrapper(f: NFWDI, key: str) -> T: + with native_function_manager(f.func): + return func(f, key) + + return wrapper diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/deprecated.yaml b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/deprecated.yaml new file mode 100644 index 0000000000000000000000000000000000000000..52f7ec50b6ea15dae1c3308358997950d295c924 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/deprecated.yaml @@ -0,0 +1,134 @@ +# Deprecated function signatures. These are exposed in Python, but not included +# in the error message suggestions. + +- name: add(Tensor self, Scalar alpha, Tensor other) -> Tensor + aten: add(self, other, alpha) + +- name: add_(Tensor(a!) self, Scalar alpha, Tensor other) -> Tensor(a!) + aten: add_(self, other, alpha) + +- name: add(Tensor self, Scalar alpha, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + aten: add_out(out, self, other, alpha) + +- name: addbmm(Scalar beta, Tensor self, Scalar alpha, Tensor batch1, Tensor batch2) -> Tensor + aten: addbmm(self, batch1, batch2, beta, alpha) + +- name: addbmm_(Scalar beta, Tensor(a!) self, Scalar alpha, Tensor batch1, Tensor batch2) -> Tensor(a!) + aten: addbmm_(self, batch1, batch2, beta, alpha) + +- name: addbmm(Scalar beta, Tensor self, Scalar alpha, Tensor batch1, Tensor batch2, *, Tensor(a!) out) -> Tensor(a!) + aten: addbmm_out(out, self, batch1, batch2, beta, alpha) + +- name: addbmm(Scalar beta, Tensor self, Tensor batch1, Tensor batch2) -> Tensor + aten: addbmm(self, batch1, batch2, beta, 1) + +- name: addbmm_(Scalar beta, Tensor(a!) self, Tensor batch1, Tensor batch2) -> Tensor(a!) + aten: addbmm_(self, batch1, batch2, beta, 1) + +- name: addbmm(Scalar beta, Tensor self, Tensor batch1, Tensor batch2, *, Tensor(a!) out) -> Tensor(a!) + aten: addbmm_out(out, self, batch1, batch2, beta, 1) + +- name: addcdiv(Tensor self, Scalar value, Tensor tensor1, Tensor tensor2) -> Tensor + aten: addcdiv(self, tensor1, tensor2, value) + +- name: addcdiv_(Tensor(a!) self, Scalar value, Tensor tensor1, Tensor tensor2) -> Tensor(a!) + aten: addcdiv_(self, tensor1, tensor2, value) + +- name: addcdiv(Tensor self, Scalar value, Tensor tensor1, Tensor tensor2, *, Tensor(a!) out) -> Tensor(a!) + aten: addcdiv_out(out, self, tensor1, tensor2, value) + +- name: addcmul(Tensor self, Scalar value, Tensor tensor1, Tensor tensor2) -> Tensor + aten: addcmul(self, tensor1, tensor2, value) + +- name: addcmul_(Tensor(a!) self, Scalar value, Tensor tensor1, Tensor tensor2) -> Tensor(a!) + aten: addcmul_(self, tensor1, tensor2, value) + +- name: addcmul(Tensor self, Scalar value, Tensor tensor1, Tensor tensor2, *, Tensor(a!) out) -> Tensor(a!) + aten: addcmul_out(out, self, tensor1, tensor2, value) + +- name: addmm(Scalar beta, Tensor self, Scalar alpha, Tensor mat1, Tensor mat2) -> Tensor + aten: addmm(self, mat1, mat2, beta, alpha) + +- name: addmm_(Scalar beta, Tensor(a!) self, Scalar alpha, Tensor mat1, Tensor mat2) -> Tensor(a!) + aten: addmm_(self, mat1, mat2, beta, alpha) + +- name: addmm(Scalar beta, Tensor self, Scalar alpha, Tensor mat1, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) + aten: addmm_out(out, self, mat1, mat2, beta, alpha) + +- name: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2) -> Tensor + aten: addmm(self, mat1, mat2, beta, 1) + +- name: addmm_(Scalar beta, Tensor(a!) self, Tensor mat1, Tensor mat2) -> Tensor(a!) + aten: addmm_(self, mat1, mat2, beta, 1) + +- name: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) + aten: addmm_out(out, self, mat1, mat2, beta, 1) + +- name: sspaddmm(Scalar beta, Tensor self, Scalar alpha, Tensor mat1, Tensor mat2) -> Tensor + aten: sspaddmm(self, mat1, mat2, beta, alpha) + +- name: sspaddmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2) -> Tensor + aten: sspaddmm(self, mat1, mat2, beta, 1) + +- name: addmv(Scalar beta, Tensor self, Scalar alpha, Tensor mat, Tensor vec) -> Tensor + aten: addmv(self, mat, vec, beta, alpha) + +- name: addmv_(Scalar beta, Tensor(a!) self, Scalar alpha, Tensor mat, Tensor vec) -> Tensor(a!) + aten: addmv_(self, mat, vec, beta, alpha) + +- name: addmv(Scalar beta, Tensor self, Scalar alpha, Tensor mat, Tensor vec, *, Tensor(a!) out) -> Tensor(a!) + aten: addmv_out(out, self, mat, vec, beta, alpha) + +- name: addmv(Scalar beta, Tensor self, Tensor mat, Tensor vec) -> Tensor + aten: addmv(self, mat, vec, beta, 1) + +- name: addmv_(Scalar beta, Tensor(a!) self, Tensor mat, Tensor vec) -> Tensor(a!) + aten: addmv_(self, mat, vec, beta, 1) + +- name: addmv(Scalar beta, Tensor self, Tensor mat, Tensor vec, *, Tensor(a!) out) -> Tensor(a!) + aten: addmv_out(out, self, mat, vec, beta, 1) + +- name: addr(Scalar beta, Tensor self, Scalar alpha, Tensor vec1, Tensor vec2) -> Tensor + aten: addr(self, vec1, vec2, beta, alpha) + +- name: addr_(Scalar beta, Tensor(a!) self, Scalar alpha, Tensor vec1, Tensor vec2) -> Tensor(a!) + aten: addr_(self, vec1, vec2, beta, alpha) + +- name: addr(Scalar beta, Tensor self, Scalar alpha, Tensor vec1, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!) + aten: addr_out(out, self, vec1, vec2, beta, alpha) + +- name: addr(Scalar beta, Tensor self, Tensor vec1, Tensor vec2) -> Tensor + aten: addr(self, vec1, vec2, beta, 1) + +- name: addr_(Scalar beta, Tensor(a!) self, Tensor vec1, Tensor vec2) -> Tensor(a!) + aten: addr_(self, vec1, vec2, beta, 1) + +- name: addr(Scalar beta, Tensor self, Tensor vec1, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!) + aten: addr_out(out, self, vec1, vec2, beta, 1) + +- name: baddbmm(Scalar beta, Tensor self, Scalar alpha, Tensor batch1, Tensor batch2) -> Tensor + aten: baddbmm(self, batch1, batch2, beta, alpha) + +- name: baddbmm_(Scalar beta, Tensor(a!) self, Scalar alpha, Tensor batch1, Tensor batch2) -> Tensor(a!) + aten: baddbmm_(self, batch1, batch2, beta, alpha) + +- name: baddbmm(Scalar beta, Tensor self, Scalar alpha, Tensor batch1, Tensor batch2, *, Tensor(a!) out) -> Tensor(a!) + aten: baddbmm_out(out, self, batch1, batch2, beta, alpha) + +- name: baddbmm(Scalar beta, Tensor self, Tensor batch1, Tensor batch2) -> Tensor + aten: baddbmm(self, batch1, batch2, beta, 1) + +- name: baddbmm_(Scalar beta, Tensor(a!) self, Tensor batch1, Tensor batch2) -> Tensor(a!) + aten: baddbmm_(self, batch1, batch2, beta, 1) + +- name: baddbmm(Scalar beta, Tensor self, Tensor batch1, Tensor batch2, *, Tensor(a!) out) -> Tensor(a!) + aten: baddbmm_out(out, self, batch1, batch2, beta, 1) + +- name: sub(Tensor self, Scalar alpha, Tensor other) -> Tensor + aten: sub(self, other, alpha) + +- name: sub_(Tensor(a!) self, Scalar alpha, Tensor other) -> Tensor(a!) + aten: sub_(self, other, alpha) + +- name: sub(Tensor self, Scalar alpha, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + aten: sub_out(out, self, other, alpha) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/derivatives.yaml b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/derivatives.yaml new file mode 100644 index 0000000000000000000000000000000000000000..88e0a316f9d09c49d7ec370cff912bba59c27136 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/derivatives.yaml @@ -0,0 +1,3242 @@ +# Defines derivative formulas and Python signatures of methods on Variable +# +# Note about possibly confusing nomenclature: An 'output gradient' is the +# gradient of an output of a forward function. Output gradients are used as +# the inputs to backward functions. `grads` is a vector of output gradients, +# and `grad == grads[0]`, in all the derivative formulas in this file. +# An 'input gradient' is the gradient of an input to a forward function. +# Input gradients are the outputs of backward functions, corresponding to the +# input names included in the derivative formulas defined in this file. +# Also, every time we talk computing "gradient" we actually mean computing +# the vector jacobian product using the given 'output gradient' as the vector. +# +# Each entry consists of: +# - A 'name', which specifies the ATen name of the function you +# are defining derivatives for, and an argument specification. +# - An optional 'dispatch' entry which can be used to specify +# per-autograd dispatch key derivatives. If this entry is not +# specified, then the gradient entries will be taken as the +# default gradients (i.e. registered for every backward dispatch +# key). (see _test_autograd_multiple_dispatch for an example +# of how to register separate derivates for different dispatch keys). +# The list of allowed dispatch keys (in addition to 'Default' which +# represents the Autograd alias key) is torchgen/model.py:AUTOGRAD_KEYS. +# - One or more gradients entries, mapping differentiable input +# names to a formula specifying how to compute its gradient. +# Note that a single gradient entry can specify the gradient +# formula for multiple input names, by specifying a key +# "input1, input2" (see atan2 for an example). +# - An argument can be flagged as 'non_differentiable'. +# - Optional entry with key 'output_differentiability' and value a list of the +# same length as the number of outputs from the forward function. The list +# should contain only booleans, specifying whether each of the output Tensor +# is differentiable. +# If it is not specified for a function that returns multiple elements but +# uses `grad` instead of `grads[idx]`, then all but the first output will +# be marked as non-differentiable. +# If None of the output is differentiable, you can also add the function +# name to `gen_variable_type.py`'s `DONT_REQUIRE_DERIVATIVE` list. +# +# There are two cases for Tensor and TensorList arguments here: +# - If that argument is differentiable, in the sense that a gradient with respect +# to that argument could exist. You should either: +# - Specify the formula for that gradient +# - Specify not_implemented("function_name") as a formula to say that this is not +# implemented yet (but might be in the future and the user can request that on an issue) +# - If that argument is not differentiable, because it is not a floating point dtype or the +# function is not differentiable with respect to that argument for +# example. You should either: +# - Do not specify any formula for this argument +# - Specify explicitly that this argument is "non_differentiable". Note that in this case, +# we trust you that this argument will never have requires_grad=True and it will be silently +# ignored if it does. +# +# If a function has out-of-place and in-place variants, then the derivative +# definition for the in-place variant is optional. It will default to the +# definition for the out-of-place variant. Note that _out variants are never +# differentiable. +# +# Gradient expressions are standard C++ expressions operating on ATen +# variables. In a gradient expression, the following variables/functions +# are in scope: +# +# - 'grad', the gradient of the output (often spelled grad_output +# in Python) which we are going to left-multiply. +# +# When a function returns multiple *differentiable* outputs, +# you can refer to the gradients of each outputs using 'grads', +# e.g., 'grads[0]', 'grads[1]'. +# +# When a function returns multiple *differentiable* outputs that +# are named, you can refer to the gradients of each outputs using +# 'grad_{name}', e.g., 'grad_x', 'grad_y'. +# +# When a function returns *one* differentiable output (the +# first output) and some more nondifferentiable outputs, +# you MUST refer to the gradient of the differentiable output with +# 'grad' (this case is special-cased in our code generation). +# +# Note that the number of differentiable outputs can be modified by the +# 'output_differentiability' entry (see above). +# +# Across a differentiable function's derivatives set, it is not +# permitted to mix the use of "grad", "grads", and +# "grad_{name}". You must be consistent for that differentiable +# function. +# +# - Any of the input arguments, tensor or non-tensor, including +# argument names that only appear in Declarations.yaml, e.g. 'output'. +# +# - 'result', representing the result of evaluating the forward +# expression for ATen native function declarations. If the forward +# expression outputs a tuple, use 'resultX' instead to access the +# X-th entry +# +# - 'grad_input_mask', a std::array, specifies which input +# gradients are actually needed. For example, in the entry +# `input0, input1: foo(grad_input_mask)`, `grad_input_mask` is a size +# two array, where `grad_input_mask[0]` is true if `input0` requires +# grad, and `grad_input_mask[1]` is true if `input1` requires grad. +# +# (NB: if your function computes gradient for a list of tensors, +# the `grad_input_mask` will only have a single entry for the list +# specifying if either zero or at least one tensor from the list requires +# grad. If we want to support more fine-grained signalling, +# we'll need some alternate variable which is not a std::array) +# +# - 'retain_variables', a bool which is true if a user has specified +# that saved variables should be retained in case the backwards is +# run again later. This allows an optimization where we can +# destroy saved buffers if we know variables are not going to be retained, +# e.g., it is used by _cudnn_rnn +# +# - `wrap_opt_if`, is a 2-argument function that accepts a tensor +# variable and a boolean condition that dictates whether to save that +# variable in a graph. The result of this function is `std::optional`, +# and it is `::std::nullopt` when the condition evaluates to `false`, +# otherwise it is the variable wrapped in `std::optional`. +# For example, wrap_opt_if(var_0, grad_input_mask[1] || grad_input_mask[2]) +# would mean that `var_0` is saved as long as the second (grad_input_mask[1]) +# or the third (grad_input_mask[2]) argument requires gradients. +# Another interpretation of this expression would read as `var_0` is needed +# in the backward computation of the second or the third argument. +# NOTE: the usage of `var_i.requires_grad()` in the conditional expression +# is not supported, use `grad_input_mask[i]` instead. +# NOTE: `wrap_opt_if` could be used to prevent saving redundant variables +# with multi-output backward formulas. +# See https://github.com/pytorch/pytorch/issues/97575 for more details +# on the issue. +# +# If you need a complex expression, e.g., with local variables, +# write a _backward function in torch/csrc/autograd/FunctionsManual.cpp +# and invoke it from here. By the way, go read +# https://github.com/zdevito/ATen/issues/163; this describes an +# important hazard that occurs when porting backwards from Python to C++ +# +# Double backwards gradient expressions can be somewhat confusing; +# the most important thing to remember is: (1) you need to define a +# derivative formula for every input, including inputs named things +# like 'grad_output', and (2) the gradient to multiply with is always +# called 'grad' (even though it really is a grad-grad). +# +# You can also add forward derivative definition by defining a formula for +# a returned value (in general "result" if the name is not specified). This +# formula works the same way as the backward one and advanced implementations +# should also be placed in the FunctionsManual file. +# This formula should compute a single Jacobian vector product using the (primal) +# value of the argument "foo_p", its forward grad "foo_t" and the result of the +# function as "result". +# Note that the forward derivative can be automatically generated in two cases: +# - if your function is linear (NOT affine or multi-linear), then you can +# specify so by just using the string "auto_linear" for the formula. +# - if your function is applied element wise (and has a single input), you +# can specify so by just using the string "auto_element_wise" for the formula. +# +# Note that to avoid unpacking overhead, functions taking TensorList as inputs +# will always have their forward grad formula called. This function is responsible +# to check if any computation is needed and should return an undefined Tensor when +# there is nothing to do. You can check "cat_forward" for a full example. +# +# NB: There are a number of gradient definitions in here which are bogus +# (implemented using zeros_like). These gradients are (hopefully) not +# used by our frontend. You MUST check the frontend code; search for +# OpName.apply to see if it's still using a legacy Python style API. +# +# Note: Returning views. +# The following cases exist: +# - If a function returns no view, it can have arbitrary outputs. +# - If a function return at least one Tensor that is a differentiable view +# of one of its input: +# - If there is only one differentiable output, this Tensor is marked as a +# differentiable view. (alias or transpose for example) +# - If there are more than one differentiable output, by default all the views are +# marked as differentiable views and created with allow_rebase_history=false. +# Meaning that any inplace operation on it will raise an error. (unbind for example) +# +# Notes about undefined output gradients: +# All backward functions must support all combinations of undefined output +# gradient Tensors, where `grad[i].defined() == false`. Depending on the +# number of input and output grads your derivative formula uses, code +# generation may automatically add some level of undefined grad support, +# according to these three cases: +# +# * 1 input grad and 1 output grad: +# Complete undefined grad support is automatically added, so you +# shouldn't have to think about it, unless there is a bug in the code +# generation. +# +# * 1 input grad and multiple output grads: +# Undefined grad support is automatically added ONLY in the case where +# all output grads are undefined. You will have to add explicit support +# for cases where a subset of output grads is undefined. +# +# * multiple input grads: +# No automatic support, so you will need to add it. +# +# If your derivative formula uses more than one output grad, it is usually +# preferable to add undefined grad support in the backward function itself +# (if you're using one), rather than in the derivative formula in this file. +# +# Undefined Tensors are created with the default constructor `at::Tensor()`. +# It is an efficient way to represent a Tensor filled with zeros because +# the Tensor holds no sizing information and no Storage data is allocated. +# But consequently, Tensor operations cannot be performed on them. +# Therefore, your backward function should treat an undefined output grad as +# a zero, and it needs to be a special case. +# +# If all output grads are undefined, then it should be correct for the +# backward function to return undefined input grads. Since we use the chain +# rule, output grads equal to zero should result in input grads equal to zero, +# unless there is some rare special case. +# +# If a subset of output grads is undefined, then it may be acceptable for +# the backward function to return undefined input grads--it depends on the +# specific function, so you'll have to determine that yourself. If returning +# an undefined Tensor is correct for a given input grad, it is also logically +# correct to return a defined grad full of zeros, but that would not be +# preferable since it would be less efficient. +# +# NB: The parameter names here MUST be consistent with the parameter names +# in native_functions.yaml +- name: abs(Tensor self) -> Tensor + self: grad * self.sgn() + result: handle_r_to_c(result.scalar_type(), self_t.conj() * self_p.sgn()) + +- name: acos(Tensor self) -> Tensor + self: grad * -((-self * self + 1).rsqrt()).conj() + result: auto_element_wise + +- name: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + self: handle_r_to_c(self.scalar_type(), grad) + other: handle_r_to_c(other.scalar_type(), maybe_multiply(grad, alpha.conj())) + result: self_t + maybe_multiply(other_t, alpha) + +- name: add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + self: handle_r_to_c(self.scalar_type(), grad) + result: self_t.clone() + +- name: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + self: maybe_multiply(grad, beta.conj()) + batch1: maybe_multiply(grad.unsqueeze(0).expand_symint({ batch1.sym_size(0), batch1.sym_size(1), batch2.sym_size(2) }).bmm(batch2.transpose(1, 2).conj()), alpha.conj()) + batch2: maybe_multiply(batch1.transpose(1, 2).conj().bmm(grad.unsqueeze(0).expand_symint({ batch1.sym_size(0), batch1.sym_size(1), batch2.sym_size(2) })), alpha.conj()) + result: maybe_multiply(self_t, beta) + maybe_multiply(batch1_t.bmm(batch2_p).sum(0), alpha) + maybe_multiply(batch1_p.bmm(batch2_t).sum(0), alpha) + +- name: addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor + self: handle_r_to_c(self.scalar_type(), grad) + tensor1: handle_r_to_c(tensor1.scalar_type(), grad * (value / tensor2).conj()) + tensor2: handle_r_to_c(tensor2.scalar_type(), -grad * (value * tensor1 / (tensor2 * tensor2)).conj()) + result: self_t + maybe_multiply(tensor1_t / tensor2_p, value) - maybe_multiply(tensor2_t * (tensor1_p / tensor2_p) / tensor2_p, value) + +- name: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor + self: handle_r_to_c(self.scalar_type(), grad) + tensor1: handle_r_to_c(tensor1.scalar_type(), grad * (tensor2 * value).conj()) + tensor2: handle_r_to_c(tensor2.scalar_type(), grad * (tensor1 * value).conj()) + result: self_t + maybe_multiply(tensor1_t * tensor2_p, value) + maybe_multiply(tensor2_t * tensor1_p, value) + +- name: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + self: maybe_multiply(grad, beta.conj()) + mat1: mm_mat1_backward(grad, mat2, mat1.sym_sizes(), mat1.sym_strides(), mat1.layout(), alpha) + mat2: mm_mat2_backward(grad, mat1, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), alpha) + result: maybe_multiply(self_t, beta) + maybe_multiply(mat1_t.mm(mat2_p), alpha) + maybe_multiply(mat1_p.mm(mat2_t), alpha) + +- name: _sparse_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + self: maybe_multiply(grad, beta) + mat1: mm_mat1_sparse_backward(grad, mat1, mat2, alpha) + mat2: mm_mat2_backward(grad, mat1, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), alpha) + +- name: addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor + self: maybe_multiply(grad, beta.conj()) + mat: maybe_multiply(grad.ger(vec.conj()), alpha.conj()) + vec: maybe_multiply(mat.t().conj().mv(grad), alpha.conj()) + result: maybe_multiply(self_t, beta) + maybe_multiply(mat_t.mv(vec_p), alpha) + maybe_multiply(mat_p.mv(vec_t), alpha) + +- name: addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + self: maybe_multiply(grad, beta.conj()) + vec1: maybe_multiply(grad.mv(vec2.conj()), alpha.conj()) + vec2: maybe_multiply(grad.t().mv(vec1.conj()), alpha.conj()) + result: maybe_multiply(self_t, beta) + maybe_multiply(vec1_t.outer(vec2_p), alpha) + maybe_multiply(vec1_p.outer(vec2_t), alpha) + +- name: affine_grid_generator(Tensor theta, SymInt[] size, bool align_corners) -> Tensor + theta: affine_grid_generator_backward_symint(grad, size, align_corners) + result: auto_linear + +- name: alias(Tensor(a) self) -> Tensor(a) + self: grad + result: self_t + +- name: angle(Tensor self) -> Tensor + self: angle_backward(grad, self) + result: handle_r_to_c(result.scalar_type(), angle_backward(self_t.conj(), self_p).conj()) + +# The four items below are necessary because TensorIterator doesn't work on +# Variables (codegen does not unwrap the input Tensor for all() and any() ). +- name: any(Tensor self) -> Tensor + output_differentiability: [False] + +- name: any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor + output_differentiability: [False] + +- name: any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor + output_differentiability: [False] + +- name: _is_all_true(Tensor self) -> Tensor + self: non_differentiable + +- name: _is_any_true(Tensor self) -> Tensor + self: non_differentiable + +- name: all(Tensor self) -> Tensor + output_differentiability: [False] + +- name: all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor + output_differentiability: [False] + +- name: all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor + output_differentiability: [False] + +- name: acosh(Tensor self) -> Tensor +# Save one rsqrt in the real case by using that for x real and positive sqrt(x*y) = sqrt(x)*sqrt(y) (not true in the complex case) + self: "self.is_complex() ? grad * ((self + 1).rsqrt() * (self - 1).rsqrt()).conj() : grad * (self * self - 1).rsqrt()" + result: auto_element_wise + +- name: acosh_(Tensor(a!) self) -> Tensor(a!) + self: not_implemented("inplace version of acosh") + +- name: asinh(Tensor self) -> Tensor + self: grad * (self.pow(2) + 1).rsqrt().conj() + result: auto_element_wise + +- name: asinh_(Tensor(a!) self) -> Tensor(a!) + self: not_implemented("inplace version of asinh") + +- name: atanh(Tensor self) -> Tensor + self: grad * 1 / (1 - self.pow(2)).conj() + result: auto_element_wise + +- name: atanh_(Tensor(a!) self) -> Tensor(a!) + self: not_implemented("inplace version of atanh") + +- name: as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a) + self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset) + result: auto_linear + +- name: as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!) + self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset) + result: auto_linear + +- name: asin(Tensor self) -> Tensor + self: grad * (-self * self + 1).rsqrt().conj() + result: auto_element_wise + +- name: atan(Tensor self) -> Tensor + self: grad / (self * self + 1).conj() + result: auto_element_wise + +- name: atan2(Tensor self, Tensor other) -> Tensor + self, other: atan2_backward(grad, self, other, grad_input_mask) + result: (-self_p * other_t + other_p * self_t) / (self_p.pow(2) + other_p.pow(2)) + +- name: baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + self: maybe_multiply(grad, beta.conj()) + batch1: maybe_multiply(grad.bmm(batch2.transpose(1, 2).conj()), alpha.conj()) + batch2: maybe_multiply(batch1.transpose(1, 2).conj().bmm(grad), alpha.conj()) + result: maybe_multiply(self_t, beta) + maybe_multiply(batch1_t.bmm(batch2_p), alpha) + maybe_multiply(batch1_p.bmm(batch2_t), alpha) + +- name: bernoulli(Tensor self, *, Generator? generator=None) -> Tensor + self: zeros_like(grad) + result: auto_element_wise + +- name: bernoulli_.Tensor(Tensor(a!) self, Tensor p, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + p: zeros_like(p) + result: self_t.zero_() + +- name: bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: bmm(Tensor self, Tensor mat2) -> Tensor + self: grad.bmm(mat2.transpose(1, 2).conj()) + mat2: self.transpose(1, 2).conj().bmm(grad) + result: self_t.bmm(mat2_p) + self_p.bmm(mat2_t) + +- name: matmul(Tensor self, Tensor other) -> Tensor + self, other: matmul_backward(grad, self, other, grad_input_mask) + +- name: cat(Tensor[] tensors, int dim=0) -> Tensor + tensors: cat_tensors_backward(grad, to_args_sizes_symint(tensors), to_args_scalartypes(tensors), dim) + result: cat_jvp(tensors, dim) + +- name: cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: ceil(Tensor self) -> Tensor + self: zeros_like(grad) + result: auto_element_wise + +- name: cholesky(Tensor self, bool upper=False) -> Tensor + self: cholesky_backward(grad, upper, result) + +- name: chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[] + dispatch: + Default: + # the default case will use the CompositeImplicitAutograd + self: not_implemented("chunk") + AutogradNestedTensor: + self: chunk_backward_nested(grads, self, chunks, dim) + +- name: linalg_cholesky_ex(Tensor self, *, bool upper=False, bool check_errors=False) -> (Tensor L, Tensor info) + self: cholesky_backward(grad, upper, L) + L: cholesky_jvp(self_t, L, upper) + +- name: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor + self, input2: cholesky_solve_backward(grad, self, input2, result, upper, grad_input_mask) + result: cholesky_solve_jvp(result, input2_p, input2_t, self_t, upper) + +- name: cholesky_inverse(Tensor self, bool upper=False) -> Tensor + self: cholesky_inverse_backward(grad, self, upper, result) + result: cholesky_inverse_jvp(self_p, self_t, result, upper) + +# For clamp, gradient is not defined at the boundaries. But empirically it's helpful +# to be able to get gradient on min and max, so we return the subgradient 1 for these cases. +- name: clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor + self: clamp_backward(grad, self, min, max) + min, max: clamp_backward_min_max(grad, self, min, max, grad_input_mask) + result: clamp_jvp(self_p, self_t, min_p, min_t, max_p, max_t) + +- name: clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor + self: clamp_backward(grad, self, min, max) + result: auto_element_wise + +- name: clamp_min(Tensor self, Scalar min) -> Tensor + self: where(self >= min, grad, at::scalar_tensor(0., grad.options())) + result: auto_element_wise + +- name: clamp_min.Tensor(Tensor self, Tensor min) -> Tensor + self: where(self >= min, grad, at::scalar_tensor(0., grad.options())) + min: where(self < min, grad, at::scalar_tensor(0., grad.options())) + result: where(self_p >= min_p, self_t, min_t) + +- name: clamp_max(Tensor self, Scalar max) -> Tensor + self: where(self <= max, grad, at::scalar_tensor(0., grad.options())) + result: auto_element_wise + +- name: clamp_max.Tensor(Tensor self, Tensor max) -> Tensor + self: where(self <= max, grad, at::scalar_tensor(0., grad.options())) + max: where(self > max, grad, at::scalar_tensor(0., grad.options())) + result: where(self_p <= max_p, self_t, max_t) + +- name: clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor + self: grad + result: auto_linear + +- name: _lazy_clone(Tensor self) -> Tensor + self: grad + result: auto_linear + +- name: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor + self: _to_copy_backward(grad, self.options()) + result: _to_copy(self_t, dtype, layout, device, pin_memory, non_blocking, memory_format) + # The condition is: if dtype is not nullopt, then isDifferentiableType(*dtype) + # (If dtype IS nullopt, we rely on the regular check that any input requires grad). + output_differentiability: ["!dtype || isDifferentiableType(*dtype)"] + +- name: _coalesce(Tensor self) -> Tensor + self: grad + +- name: complex(Tensor real, Tensor imag) -> Tensor + real: at::real(grad) + imag: at::imag(grad) + result: at::complex(real_t, imag_t) + +- name: polar(Tensor abs, Tensor angle) -> Tensor + abs, angle: polar_backward(grad, result) + result: at::complex(abs_t*angle_p.cos() - angle_t*abs_p*angle_p.sin(), abs_t*angle_p.sin() + angle_t*abs_p*angle_p.cos()) + +- name: _conj(Tensor(a) self) -> Tensor(a) + self: grad.conj() + result: self_t.conj() + +- name: _neg_view(Tensor(a) self) -> Tensor(a) + self: grad.neg() + result: self_t._neg_view() + +- name: _conj_physical(Tensor self) -> Tensor + self: grad.conj_physical() + result: self_t.conj_physical() + +- name: conj_physical_(Tensor(a!) self) -> Tensor(a!) + self: grad.conj_physical() + result: self_t.conj_physical_() + +- name: copysign.Tensor(Tensor self, Tensor other) -> Tensor + self: copysign_tensor_self_backward(grad, self, result) + other: zeros_like(other) + result: copysign_tensor_self_backward(self_t, self_p, result) + +- name: copysign.Scalar(Tensor self, Scalar other) -> Tensor + self: copysign_tensor_self_backward(grad, self, result) + result: auto_element_wise + +- name: cos(Tensor self) -> Tensor + self: grad * -self.sin().conj() + result: auto_element_wise + +- name: cosh(Tensor self) -> Tensor + self: grad * self.sinh().conj() + result: auto_element_wise + +- name: count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor + output_differentiability: [False] + +- name: count_nonzero(Tensor self, int? dim=None) -> Tensor + output_differentiability: [False] + +- name: linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor + self: at::linalg_cross(other.conj(), grad, dim) + other: at::linalg_cross(grad, self.conj(), dim) + result: "at::linalg_cross(self_t, other_p, dim) + at::linalg_cross(self_p, other_t, dim)" + +- name: logcumsumexp(Tensor self, int dim) -> Tensor + self: logcumsumexp_backward(grad, self, result, dim) + result: logcumsumexp_jvp(self_p, self_t, dim) + +- name: cumprod(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor + self: cumprod_backward(grad.to(self.scalar_type()), self, dim, result) + result: "cumprod_jvp(self_t, self_p, result, dim).to(dtype.has_value() ? *dtype : self_p.scalar_type())" + +- name: cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor + self: cumsum_backward(grad.to(self.scalar_type()), dim) + result: auto_linear + +- name: cummax(Tensor self, int dim) -> (Tensor values, Tensor indices) + self: cummaxmin_backward(grad, self, indices, dim) + values: self_t.gather(dim, indices) + +- name: cummin(Tensor self, int dim) -> (Tensor values, Tensor indices) + self: cummaxmin_backward(grad, self, indices, dim) + values: self_t.gather(dim, indices) + +- name: conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad=0) -> Tensor + self, weight, bias: "grad.defined() ? conv_tbc_backward(grad, self, weight, bias, pad) : std::tuple()" + +- name: _ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor) + log_probs: _ctc_loss_backward(grad, log_probs, targets, input_lengths, target_lengths, result0, result1, blank, zero_infinity) + +- name: _ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor) + log_probs: _ctc_loss_backward(grad, log_probs, targets, input_lengths, target_lengths, result0, result1, blank, zero_infinity) + +- name: deg2rad(Tensor self) -> Tensor + self: deg2rad_backward(grad) + result: auto_element_wise + +- name: _linalg_det(Tensor A) -> (Tensor result, Tensor LU, Tensor pivots) + A: linalg_det_backward(grad, result, A, LU, pivots) + result: linalg_det_jvp(A_t, result, LU, pivots, A_p.is_contiguous() && !A_p.is_complex()) + output_differentiability: [True, False, False] + +- name: _linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet, Tensor LU, Tensor pivots) + A: slogdet_backward(grad_sign, grad_logabsdet, A, sign, LU, pivots) + sign, logabsdet: slogdet_jvp(LU, pivots, A_t, sign, A_p.is_contiguous() && !A_p.is_complex()) + output_differentiability: [True, True, False, False] + +- name: block_diag(Tensor[] tensors) -> Tensor + tensors: block_diag_backward(grad, to_args_sizes(tensors), to_args_scalartypes(tensors)) + result: block_diag_jvp(tensors) + +- name: diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor + self: grad.diagonal(offset, dim1, dim2) + result: auto_linear + +- name: diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a) + self: diagonal_backward_symint(grad, self.sym_sizes(), offset, dim1, dim2) + result: auto_linear + +- name: diagonal_backward(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2) -> Tensor + grad_output: grad.diagonal(offset, dim1, dim2) + result: auto_linear + +- name: dist(Tensor self, Tensor other, Scalar p=2) -> Tensor + self: norm_backward(grad, self - other, p, result) + other: -norm_backward(grad, self - other, p, result) + result: norm_jvp(self_p - other_p, self_t - other_t, p, result, {}, false) + +# The backward formula is done in this order to improve numerical stability +# of the higher order derivatives, see https://github.com/pytorch/pytorch/issues/43414 +# Note that we don't use "result" because saving it would be BC-breaking when it is used in an inplace operation later +- name: div.Tensor(Tensor self, Tensor other) -> Tensor + self: div_tensor_self_backward(grad, other, self.scalar_type()) + other: div_tensor_other_backward(grad, self, other) + result: (self_t - other_t * result) / other_p + +- name: div.Scalar(Tensor self, Scalar other) -> Tensor + self: div_tensor_self_backward(grad, other, self.scalar_type()) + result: self_t / other + +- name: div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor + self: div_tensor_self_backward(grad, other, self.scalar_type(), rounding_mode) + other: div_tensor_other_backward(grad, self, other, rounding_mode) + result: "rounding_mode.has_value() ? result.new_zeros_symint(result.sym_sizes()) : self_t / other_p - other_t * (self_p / other_p) / other_p" + +- name: div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor + self: div_tensor_self_backward(grad, other, self.scalar_type(), rounding_mode) + result: "rounding_mode.has_value() ? result.new_zeros_symint(result.sym_sizes()) : self_t / other" + +- name: dot(Tensor self, Tensor tensor) -> Tensor + self: grad * tensor.conj() + tensor: grad * self.conj() + result: at::dot(self_t, tensor_p) + at::dot(self_p, tensor_t) + +- name: vdot(Tensor self, Tensor other) -> Tensor + self: grad.conj() * other + other: grad * self + result: at::vdot(self_t, other_p) + at::vdot(self_p, other_t) + +- name: _fused_dropout(Tensor self, float p, Generator? generator=None) -> (Tensor, Tensor) + self: _fused_dropout_backward(grad, result1, p) + +- name: native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor) + input: "GradMode::is_enabled() ? infinitely_differentiable_native_dropout_backward(grad, result1, (!train.has_value() || !train.value() ? 1 : (p == 1 ? 0.0 : 1.0 / (1.0 - p)))) : native_dropout_backward(grad, result1, (!train.has_value() || !train.value() ? 1 : (p == 1 ? 0.0 : 1.0 / (1.0 - p))))" + result0: "(!train.has_value() || train.value()) ? (p == 1 ? 0.0 : 1.0 / (1.0 - p)) * input_t * result1 : input_t" + +- name: native_dropout_backward(Tensor grad_output, Tensor mask, float scale) -> Tensor + grad_output: "native_dropout_double_backward(grad, grad_output, mask, scale)" + mask: 'not_implemented("native_dropout_backward: mask")' + +- name: eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + self: zeros_like(self) + result: self_t.zero_() + +- name: eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + self: zeros_like(self) + other: zeros_like(other) + result: self_t.zero_() + +- name: erf(Tensor self) -> Tensor + self: 2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad + result: auto_element_wise + +- name: erfc(Tensor self) -> Tensor + self: -2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad + result: auto_element_wise + +- name: special_erfcx(Tensor self) -> Tensor + self: (2.0 * self * result - 2.0 / sqrt(M_PI)) * grad + result: auto_element_wise + +- name: erfinv(Tensor self) -> Tensor + self: 0.5 * sqrt(M_PI) * exp(self.erfinv().pow(2)) * grad + result: auto_element_wise + +- name: exp(Tensor self) -> Tensor + self: grad * result.conj() + result: auto_element_wise + +- name: exp2(Tensor self) -> Tensor + self: grad * result.conj() * M_LN2 + result: auto_element_wise + +- name: expm1(Tensor self) -> Tensor + self: grad * (result.conj() + 1) + result: auto_element_wise + +# TODO: this derivative is not SymInt safe, need sum_to support +- name: expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) + self: at::sum_to(grad, self.sym_sizes()) + result: auto_linear + +- name: exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: fake_quantize_per_tensor_affine_cachemask(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor output, Tensor mask) + self: fake_quantize_per_tensor_affine_cachemask_backward(grad, mask) + +- name: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, Tensor fake_quant_enabled, int quant_min, int quant_max) -> (Tensor output, Tensor mask) + self: fake_quantize_per_tensor_affine_cachemask_backward(grad, mask) + +- name: _fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor + self, scale, zero_point: "grad.defined() ? _fake_quantize_learnable_per_tensor_affine_backward(grad, self, scale, zero_point, quant_min, quant_max, grad_factor) : std::tuple()" + +- name: fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask) + self: fake_quantize_per_channel_affine_cachemask_backward(grad, mask) + +- name: _fake_quantize_learnable_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor + self, scale, zero_point: "grad.defined() ? _fake_quantize_learnable_per_channel_affine_backward(grad, self, scale, zero_point, axis, quant_min, quant_max, grad_factor) : std::tuple()" + +- name: _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) + self: fake_quantize_per_tensor_affine_cachemask_backward(grad, mask) + +- name: fill.Scalar(Tensor self, Scalar value) -> Tensor + self: zeros_like(grad) + result: at::fill(self_t, 0) + +- name: fill.Tensor(Tensor self, Tensor value) -> Tensor + self: zeros_like(grad) + value: grad.sum() + result: at::fill(self_t, value_t) + +- name: fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.fill_(0) + +- name: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!) + self: zeros_like(grad) + value: grad.sum() + result: self_t.fill_(value_t) + +- name: floor(Tensor self) -> Tensor + self: zeros_like(grad) + result: auto_element_wise + +- name: fmod.Scalar(Tensor self, Scalar other) -> Tensor + self: grad + result: auto_element_wise + +- name: fmod.Tensor(Tensor self, Tensor other) -> Tensor + self: grad + other: -grad * self.div(other, /*rounding_mode=*/"trunc") + result: self_t - other_t * self_p.div(other_p, /*rounding_mode=*/"trunc") + +- name: frac(Tensor self) -> Tensor + self: grad + result: self_t + +- name: frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent) + self: grad / exponent.exp2() + mantissa: self_t / exponent.exp2() + +- name: gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor + self: gather_backward(grad, self, dim, index, sparse_grad) + index: non_differentiable + result: auto_linear + +- name: ge_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + self: zeros_like(self) + result: self_t.zero_() + +- name: ge_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + self: zeros_like(self) + other: zeros_like(other) + result: self_t.zero_() + +- name: geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: geqrf(Tensor self) -> (Tensor a, Tensor tau) + self: not_implemented("geqrf") + +- name: indices(Tensor(a) self) -> Tensor(a) + output_differentiability: [False] + +- name: _indices(Tensor(a) self) -> Tensor(a) + output_differentiability: [False] + +- name: crow_indices(Tensor(a) self) -> Tensor(a) + output_differentiability: [False] + +- name: col_indices(Tensor(a) self) -> Tensor(a) + output_differentiability: [False] + +- name: ccol_indices(Tensor(a) self) -> Tensor(a) + output_differentiability: [False] + +- name: row_indices(Tensor(a) self) -> Tensor(a) + output_differentiability: [False] + +- name: grid_sampler_2d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + input, grid: "grad.defined() ? grid_sampler_2d_backward(grad, input, grid, interpolation_mode, padding_mode, align_corners, grad_input_mask) : std::tuple()" + +- name: grid_sampler_3d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + input, grid: "grad.defined() ? grid_sampler_3d_backward(grad, input, grid, interpolation_mode, padding_mode, align_corners, grad_input_mask) : std::tuple()" + +# See NOTE [ grid_sample CPU fallback ] +- name: _grid_sampler_2d_cpu_fallback(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + input, grid: "grad.defined() ? _grid_sampler_2d_cpu_fallback_backward(grad, input, grid, interpolation_mode, padding_mode, align_corners) : std::tuple()" + +- name: gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + self: zeros_like(self) + result: self_t.zero_() + +- name: gt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + self: zeros_like(self) + other: zeros_like(other) + result: self_t.zero_() + +- name: hardsigmoid(Tensor self) -> Tensor + self: hardsigmoid_backward(grad, self) + result: auto_element_wise + +- name: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor + output_differentiability: [False] + +- name: hardswish(Tensor self) -> Tensor + self: hardswish_backward(grad, self) + result: auto_element_wise + +- name: hardswish_backward(Tensor grad_output, Tensor self) -> Tensor + grad_output: hardswish_backward(grad, self) + self: at::where(at::logical_and(-3.0 < self, self < 3.0), grad * grad_output / 3.0, at::zeros({}, self.options())) + result: "hardswish_backward(grad_output_t, self_p) + + at::where(at::logical_and(-3.0 < self_p, self_p < 3.0), self_t * grad_output_p / 3.0, at::zeros({}, self_p.options()))" + +- name: hypot(Tensor self, Tensor other) -> Tensor + self: grad * self / result + other: grad * other / result + result: self_t * self_p / result + other_t * other_p / result + +- name: i0(Tensor self) -> Tensor + self: grad * at::special_i1(self) + result: auto_element_wise + +- name: special_i0e(Tensor self) -> Tensor + self: grad * (at::special_i1e(self) - self.sgn() * result) + result: auto_element_wise + +- name: special_i1(Tensor self) -> Tensor + self: i1_backward(grad, self, result) + result: auto_element_wise + +- name: special_i1e(Tensor self) -> Tensor + self: i1e_backward(grad, self, result) + result: auto_element_wise + +- name: igamma(Tensor self, Tensor other) -> Tensor + self: 'not_implemented("igamma: input")' + other: grad * exp((self - 1) * log(other) - other - lgamma(self)) + +- name: igammac(Tensor self, Tensor other) -> Tensor + self: 'not_implemented("igammac: input")' + other: -grad * exp((self - 1) * log(other) - other - lgamma(self)) + +- name: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor + self: index_backward(grad.new_zeros_symint(self.sym_sizes(), self.options()), indices, grad) + result: auto_linear + +- name: _unsafe_index.Tensor(Tensor self, Tensor?[] indices) -> Tensor + self: at::_unsafe_index_put(grad.new_zeros_symint(self.sym_sizes(), self.options()), indices, grad, true) + result: auto_linear + +- name: _unsafe_masked_index(Tensor self, Tensor mask, Tensor?[] indices, Scalar fill) -> Tensor + self: at::_unsafe_masked_index_put_accumulate(grad.new_zeros_symint(self.sym_sizes(), self.options()), mask, indices, grad) + mask: non_differentiable + result: _unsafe_masked_index(self_t, mask, indices, 0) + +- name: _unsafe_masked_index_put_accumulate(Tensor self, Tensor mask, Tensor?[] indices, Tensor values) -> Tensor + self: grad + mask: non_differentiable + values: at::_unsafe_masked_index(grad, mask, indices, 0) + result: at::_unsafe_masked_index_put_accumulate(self_t, mask, indices, values_t) + +- name: index_add(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor + self: grad + # The case source.dim() == 0 is necessary to support scalar tensors of the form + # source.dim() == 0 and index.dim() == 1 and index.size() == (1,), + # This is because source is not broadcastable to index, as source.dim() < index.dim() + source: "maybe_multiply(source.dim() > 0 ? grad.index_select(dim, index).expand_as(source) : grad.index_select(dim, index.squeeze(0)), alpha)" + index: non_differentiable + result: at::index_add(self_t, dim, index, maybe_multiply(source_t, alpha)) + +- name: index_reduce(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor + self, source: index_reduce_backward(grad, self, dim, index, source, reduce, include_self, result) + index: non_differentiable + +- name: index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor + self: grad.index_fill(dim, index, 0) + # The case source.dim() == 0 is necessary to support scalar tensors of the form + # source.dim() == 0 and index.dim() == 1 and index.size() == (1,), + # This is because source is not broadcastable to index, as source.dim() < index.dim() + source: "source.dim() > 0 ? grad.index_select(dim, index).expand_as(source) : grad.index_select(dim, index.squeeze(0))" + index: non_differentiable + result: self_t.index_copy(dim, index, source_t) + +- name: index_fill.int_Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor + self: grad.index_fill(dim, index, 0) + index: non_differentiable + result: self_t.index_fill(dim, index, 0) + +- name: index_fill.int_Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor + self: grad.index_fill(dim, index, 0) + value: grad.index_select(dim, std::get<0>(at::_unique(index, /*sorted=*/false))).sum() + index: non_differentiable + result: self_t.index_fill(dim, index, value_t) + +- name: index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor + self: "accumulate ? grad : grad.index_put(indices, zeros_like(values), false)" + values: grad.index(indices) + result: self_t.index_put(indices, values_t, accumulate) + +- name: _unsafe_index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor + self: "accumulate ? grad : at::_unsafe_index_put(grad, indices, zeros_like(values), false)" + values: at::_unsafe_index(grad, indices) + result: at::_unsafe_index_put(self_t, indices, values_t, accumulate) + +- name: _index_put_impl_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!) + self: "accumulate ? grad : grad.index_put(indices, zeros_like(values), false)" + values: grad.index(indices) + result: at::_index_put_impl_(self_t, indices, values_t, accumulate, unsafe) + +- name: index_select(Tensor self, int dim, Tensor index) -> Tensor + self: index_select_backward_symint(grad, self.sym_sizes(), dim, index) + index: non_differentiable + result: auto_linear + +- name: linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info) + A: -at::matmul(inverse.mH(), at::matmul(grad, inverse.mH())) + inverse: -at::matmul(at::matmul(inverse, A_t), inverse) + output_differentiability: [True, False] + +- name: linalg_pinv.atol_rtol_tensor(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False) -> Tensor + self: pinv_backward(grad, result, self) + result: pinv_jvp(self_p, result, self_t) + +- name: isnan(Tensor self) -> Tensor + self: non_differentiable + +- name: kthvalue(Tensor self, SymInt k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) + values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) + +- name: le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + self: zeros_like(self) + result: self_t.zero_() + +- name: le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + self: zeros_like(self) + other: zeros_like(other) + result: self_t.zero_() + +- name: lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor + self: "weight.isComplex() ? grad * (1 - weight.conj().toComplexDouble()) : grad * (1 - weight.toDouble())" + end: grad * weight.conj() + result: at::lerp(self_t, end_t, weight) + +- name: lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor + self: grad * (1 - weight).conj() + end: grad * weight.conj() + weight: grad * (end - self).conj() + result: at::lerp(self_t, end_t, weight_p) + weight_t * (end_p - self_p) + +- name: lgamma(Tensor self) -> Tensor + self: grad * digamma(self) + result: auto_element_wise + +- name: digamma(Tensor self) -> Tensor + self: grad * polygamma(1, self) + result: auto_element_wise + +- name: polygamma(int n, Tensor self) -> Tensor + self: grad * polygamma(n + 1, self) + result: auto_element_wise + +- name: polygamma_(Tensor(a!) self, int n) -> Tensor(a!) + self: grad * polygamma(n + 1, self) + result: self_t.mul_(polygamma(n + 1, original_self_p)) + +- name: log(Tensor self) -> Tensor + self: grad.div(self.conj()) + result: auto_element_wise + +- name: log10(Tensor self) -> Tensor + self: grad / (self.conj() * 2.3025850929940456) + result: auto_element_wise + +- name: log1p(Tensor self) -> Tensor + self: log1p_backward(grad, self) + result: auto_element_wise + +- name: log2(Tensor self) -> Tensor + self: grad / (self.conj() * 0.6931471805599453) + result: auto_element_wise + +- name: logaddexp(Tensor self, Tensor other) -> Tensor + self: grad / (1 + exp(other - self)).conj() + other: grad / (1 + exp(self - other)).conj() + result: self_t / (1 + exp(other_p - self_p)) + other_t / (1 + exp(self_p - other_p)) + +- name: logaddexp2(Tensor self, Tensor other) -> Tensor + self: grad / (1 + pow(2, other - self)) + other: grad / (1 + pow(2, self - other)) + result: self_t / (1 + pow(2, other_p - self_p)) + other_t / (1 + pow(2, self_p - other_p)) + +# Note [Gradient formula for xlogy at x = 0, y <= 0] +# x * log(y) is not defined at y <= 0, so we cannot even talk about differentiability +# Now, xlogy(0, y) = 0 by definition. +# This does not make it differentiable as it's not defined in a neighbourhood of a point +# (0, y) when y <= 0. +# Now, when a function is non-differentiable, sometimes we return "a relatively sensible value" +# In this case, as per the discussion in https://github.com/pytorch/pytorch/issues/80770, we choose +# this value to be zero, which is the directional derivative along the line {x = 0}. +- name: xlogy.Tensor(Tensor self, Tensor other) -> Tensor + self: at::xlogy(grad, other).masked_fill((self == 0.) & (other <= 0.), 0.) + other: grad * self / other + result: at::xlogy(self_t, other_p).masked_fill((self_p == 0.) & (other_p <= 0.), 0.) + other_t * self_p / other_p + +- name: xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor + other: grad * self / other + result: auto_element_wise + +- name: xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor + self: "other.toDouble() > 0. + ? at::xlogy(grad, other) + : at::xlogy(grad, other).masked_fill(self == 0., 0.)" + result: auto_element_wise + +# See Note [Gradient formula for xlogy at x = 0, y <= 0] +# Same here but with y <= -1 +- name: special_xlog1py(Tensor self, Tensor other) -> Tensor + self: at::special_xlog1py(grad, other).masked_fill((self == 0.) & (other <= -1.), 0.) + other: grad * self / (other + 1) + result: at::special_xlog1py(self_t, other_p).masked_fill((self_p == 0.) & (other_p <= -1.), 0.) + other_t * self_p / (other_p + 1) + +- name: special_xlog1py.self_scalar(Scalar self, Tensor other) -> Tensor + other: grad * self / (other + 1) + result: auto_element_wise + +- name: special_xlog1py.other_scalar(Tensor self, Scalar other) -> Tensor + self: "other.toDouble() > -1. + ? at::special_xlog1py(grad, other) + : at::special_xlog1py(grad, other).masked_fill(self == 0., 0.)" + result: auto_element_wise + +- name: special_zeta(Tensor self, Tensor other) -> Tensor + self: not_implemented("zeta") + other: grad * -self * special_zeta(self + 1., other) + +- name: special_zeta.self_scalar(Scalar self, Tensor other) -> Tensor + other: grad * -self * special_zeta(self.toDouble() + 1., other) + +- name: special_zeta.other_scalar(Tensor self, Scalar other) -> Tensor + self: not_implemented("zeta") + +- name: log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor + self: logsumexp_backward(grad, self, result, dim, keepdim) + result: logsumexp_jvp(self_p, self_t, dim, keepdim) + +- name: linalg_lstsq(Tensor self, Tensor b, float? rcond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values) + self, b: linalg_lstsq_backward(grads[0], grads[1], self, b, solution, grad_input_mask) + solution: linalg_lstsq_solution_jvp(self_p, b_p, self_t, b_t) + residuals: linalg_lstsq_residuals_jvp(self_p, b_p, self_t, b_t, solution, residuals) + output_differentiability: [True, True, False, False] + +- name: lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + self: zeros_like(self) + result: self_t.zero_() + +- name: lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + self: zeros_like(self) + other: zeros_like(other) + result: self_t.zero_() + +- name: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info) + A: lu_factor_ex_backward(grad, LU, pivots, pivot) + LU: lu_factor_ex_jvp(A_t, LU, pivots, pivot) + output_differentiability: [True, False, False] + +- name: linalg_lu(Tensor A, *, bool pivot=True) -> (Tensor P, Tensor L, Tensor U) + A: linalg_lu_backward(grad_L, grad_U, P, L, U, pivot) + L: std::get<0>(linalg_lu_jvp(A_t, P, L, U, pivot)) + U: std::get<1>(linalg_lu_jvp(A_t, P, L, U, pivot)) + output_differentiability: [False, True, True] + +- name: linalg_lu_solve(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False) -> Tensor + LU: linalg_lu_solve_LU(grad, LU, pivots, result, left, adjoint) + B: "at::linalg_lu_solve(LU, pivots, grad, left, !adjoint)" + result: linalg_lu_solve_jvp(result, LU_p, pivots, LU_t, B_t, left, adjoint) + +- name: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U) + LU_data: lu_unpack_backward(grad_L, grad_U, LU_data.sym_size(-2), LU_data.sym_size(-1)) + LU_pivots: non_differentiable + L: "LU_data_t.sym_size(-2) >= LU_data_t.sym_size(-1) ? LU_data_t.tril_symint(-1) : LU_data_t.narrow_symint(-1, 0, LU_data_t.sym_size(-2)).tril_symint(-1)" + U: "LU_data_t.sym_size(-1) >= LU_data_t.sym_size(-2) ? LU_data_t.triu_symint() : LU_data_t.narrow_symint(-2, 0, LU_data_t.sym_size(-1)).triu_symint()" + output_differentiability: [False, True, True] + +- name: masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor + self: grad.masked_fill(mask, 0) + mask: non_differentiable + result: self_t.masked_fill(mask, 0) + +- name: masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor + self: grad.masked_fill(mask, 0) + value: masked_fill_backward(grad, mask) + mask: non_differentiable + result: self_t.masked_fill(mask, value_t) + +- name: masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor + self: grad.masked_fill(mask, 0) + source: masked_scatter_backward_symint(grad, mask, source.sym_sizes()) + mask: non_differentiable + result: self_t.masked_scatter(mask, source_t) + +- name: masked_scatter_backward(Tensor grad_output, Tensor mask, SymInt[] sizes) -> Tensor + grad_output: zeros_like(grad_output).masked_scatter(mask, grad) + mask: non_differentiable + result: masked_scatter_backward(grad_output_t, mask, grad_output_t.sizes()) + +- name: masked_select(Tensor self, Tensor mask) -> Tensor + self: masked_select_backward(grad, self, mask) + mask: non_differentiable + result: auto_linear + +- name: linalg_matrix_exp(Tensor self) -> Tensor + self: linalg_matrix_exp_differential(self, grad, /*adjoint*/ true) + result: linalg_matrix_exp_differential(self_p, self_t, /*adjoint*/ false) + +- name: max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) + values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) + +- name: max(Tensor self) -> Tensor + self: evenly_distribute_backward(grad, self, result) + result: evenly_read_jvp(self_t, self_p, result) + +- name: maximum(Tensor self, Tensor other) -> Tensor + self: at::where(self == other, grad / 2, grad).masked_fill_(self < other, 0) + other: at::where(self == other, grad / 2, grad).masked_fill_(self > other, 0) + result: other_t + at::where(self_p == other_p, at::scalar_tensor(0.5, result.options()), (self_p > other_p).to(result.scalar_type())) * (self_t - other_t) + +- name: fmax(Tensor self, Tensor other) -> Tensor + self: grad.masked_fill((self >= other).logical_or_(other.isnan()).logical_not_(), 0) + other: grad.masked_fill((self >= other).logical_or_(other.isnan()), 0) + result: other_t + (self_p > other_p).logical_or_(other_p.isnan()) * (self_t - other_t) + +- name: mean(Tensor self, *, ScalarType? dtype=None) -> Tensor + dispatch: + Default: + self: grad.expand_symint(self.sym_sizes()) / self.sym_numel() + result: auto_linear + AutogradNestedTensor: + # TODO: replace this with grad.expand_as(self) / self.sym_numel() when that is supported + self: (ones_like(self) * grad) / self.sym_numel() + result: auto_linear + +- name: mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + self: mean_backward(grad, self.sym_sizes(), dim, self.sym_numel(), keepdim) + result: auto_linear + +- name: median(Tensor self) -> Tensor + self: evenly_distribute_backward(grad, self, result) + result: evenly_read_jvp(self_t, self_p, result) + +- name: nanmedian(Tensor self) -> Tensor + self: evenly_distribute_backward(grad, self, result) + result: evenly_read_jvp(self_t, self_p, result) + +# This is in theory incorrect in the following case: +# sorted list: [..., a, b, b, ..., b, b, c, ...] with median = b and the value +# | at middle position of the +# | list between two `b`s. E.g., +# | +# ^the middle position +# The gradient exists and is essentially 0 in this case. +# +# In case where the middle position is at the boundary of `b` range, e.g., +# sorted list: [..., a, b, b, ..., b, b, c, ...] +# | +# ^the middle position +# The backward implementation is correct in the sense that it returns the +# subgradient on one side. +- name: median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) + values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) + +- name: nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) + values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) + +- name: min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) + values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) + +- name: min(Tensor self) -> Tensor + self: evenly_distribute_backward(grad, self, result) + result: evenly_read_jvp(self_t, self_p, result) + +- name: minimum(Tensor self, Tensor other) -> Tensor + self: at::where(self == other, grad / 2, grad).masked_fill_(self > other, 0) + other: at::where(self == other, grad / 2, grad).masked_fill_(self < other, 0) + result: other_t + at::where(self_p == other_p, at::scalar_tensor(0.5, result.options()), (self_p < other_p).to(result.scalar_type())) * (self_t - other_t) + +- name: fmin(Tensor self, Tensor other) -> Tensor + self: grad.masked_fill((self <= other).logical_or_(other.isnan()).logical_not_(), 0) + other: grad.masked_fill((self <= other).logical_or_(other.isnan()), 0) + result: other_t + (self_p <= other_p).logical_or_(other_p.isnan()) * (self_t - other_t) + +- name: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor + self: scale_grad_by_count(restore_reduced_dims(grad, dim, keepdim), restore_reduced_dims(result, dim, keepdim) == self, dim) + result: amaxamin_jvp(self_p, self_t, result, dim, keepdim) + +- name: amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor + self: scale_grad_by_count(restore_reduced_dims(grad, dim, keepdim), restore_reduced_dims(result, dim, keepdim) == self, dim) + result: amaxamin_jvp(self_p, self_t, result, dim, keepdim) + +- name: mm(Tensor self, Tensor mat2) -> Tensor + self: mm_mat1_backward(grad, mat2, self.sym_sizes(), self.sym_strides(), self.layout(), 1) + mat2: mm_mat2_backward(grad, self, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), 1) + result: at::mm(self_t, mat2_p) + at::mm(self_p, mat2_t) + +- name: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor + self: _grouped_mm_mat1_backward(grad, mat2, self.sym_sizes(), self.sym_strides(), self.layout(), offs, 1) + mat2: _grouped_mm_mat2_backward(grad, self, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), offs, 1) + +- name: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) + values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) + +- name: mul.Tensor(Tensor self, Tensor other) -> Tensor + self: mul_tensor_backward(grad, other, self.scalar_type()) + other: mul_tensor_backward(grad, self, other.scalar_type()) + result: other_t * self_p + self_t * other_p + +- name: mul.Scalar(Tensor self, Scalar other) -> Tensor + self: mul_tensor_backward(grad, other, self.scalar_type()) + result: self_t * other + +- name: mv(Tensor self, Tensor vec) -> Tensor + self: grad.ger(vec.conj()) + vec: self.conj().t().mv(grad) + result: mv(self_t, vec_p) + mv(self_p, vec_t) + +- name: mvlgamma(Tensor self, int p) -> Tensor + self: mvlgamma_backward(grad, self, p) + result: auto_element_wise + +- name: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor + self: grad * at::isfinite(self) + result: auto_element_wise + +- name: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask) : std::tuple()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, eps) + +- name: _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask) : std::tuple()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, eps) + +- name: _native_batch_norm_legit_no_training(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, /*training=*/false, eps, grad_input_mask) : std::tuple()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, /*training=*/false, eps) + +- name: _native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, Tensor(), Tensor(), result1, result2, training, eps, grad_input_mask) : std::tuple()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, Tensor(), Tensor(), result1, result2, training, eps) + +- name: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, train, eps, save_mean, save_invstd, grad_input_mask) + save_mean: not_implemented("native_batch_norm_backward save_mean") + save_invstd: not_implemented("native_batch_norm_backward save_invstd") + +- name: native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? native_layer_norm_backward_symint(grad, input, normalized_shape, result1, result2, weight, bias, grad_input_mask) : std::tuple()" + result0: layer_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, normalized_shape) + +- name: native_layer_norm_backward(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + input, weight, grad_out: layer_norm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, mean, rstd, normalized_shape, grad_input_mask) + bias: Tensor() + mean: not_implemented("native_layer_norm_backward mean") + rstd: not_implemented("native_layer_norm_backward rstd") + +- name: _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor) + input, weight: "GradMode::is_enabled() || grads[1].defined() ? infinitely_differentiable_native_rms_norm_backward(grads[0], grads[1], input, normalized_shape, result1, weight, grad_input_mask) : (grads[0].defined() ? _fused_rms_norm_backward(grads[0], input, normalized_shape, result1, weight, grad_input_mask) : std::tuple())" + result0: rms_norm_jvp(input_p, input_t, weight_p, weight_t, result1, normalized_shape) + result1: rms_norm_rstd_jvp(input_p, input_t, result1, normalized_shape) + +- name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) + input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward_symint(grads[0].device().is_xpu() ? grads[0] : grads[0].contiguous(grads[0].device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), input.device().is_xpu() ? input : input.contiguous(input.device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple())" + result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group) + result1: group_norm_mean_jvp(input_t, result1, group) + result2: group_norm_invstd_jvp(input_p, input_t, result1, result2, group) + +- name: ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + self: zeros_like(self) + result: self_t.zero_() + +- name: ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + self: zeros_like(self) + other: zeros_like(other) + result: self_t.zero_() + +- name: neg(Tensor self) -> Tensor + self: grad.neg() + result: auto_element_wise + +- name: _batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, /*update*/true, eps, grad_input_mask, retain_variables ? result3.clone() : result3) : std::tuple()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, true, eps) + +- name: _batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, /*update*/false, eps, grad_input_mask, retain_variables ? result3.clone() : result3) : std::tuple()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, false, eps) + +- name: batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor) + input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, update, eps, save_mean, save_var, grad_input_mask) + save_mean: not_implemented("batch_norm_backward save_mean") + save_var: not_implemented("batch_norm_backward save_var") + reserve: not_implemented("batch_norm_backward reserve") + +- name: nextafter(Tensor self, Tensor other) -> Tensor + self: not_implemented("nextafter") + other: not_implemented("nextafter") + +- name: norm.Scalar(Tensor self, Scalar p=2) -> Tensor + self: norm_backward(grad, self, p, result) + result: norm_jvp(self_p, self_t, p, result) + +- name: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor + self: norm_backward(grad, self, p, result, dim, keepdim) + result: norm_jvp(self_p, self_t, p, result, dim, keepdim) + +- name: norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor + self: norm_backward(grad, self.to(grad.scalar_type()), p, result) + result: norm_jvp(self_p, self_t, p, result) + +- name: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor + self: norm_backward(grad, self.to(grad.scalar_type()), p, result, dim, keepdim) + result: norm_jvp(self_p, self_t, p, result, dim, keepdim) + +- name: linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + self: linalg_vector_norm_backward(grad, self, ord, result, dim, keepdim) + result: linalg_vector_norm_jvp(self_p, self_t, ord, result, dim, keepdim) + +- name: _pdist_forward(Tensor self, float p=2) -> Tensor + self: _pdist_backward(grad, self, p, result) + +- name: _pdist_backward(Tensor grad, Tensor self, float p, Tensor pdist) -> Tensor + grad: not_implemented("_pdist_backward") + self: not_implemented("_pdist_backward") + pdist: not_implemented("_pdist_backward") + +- name: _euclidean_dist(Tensor x1, Tensor x2) -> Tensor + x1, x2: _euclidean_dist_backward(grad, x1, x2, result) + +- name: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor + x1: _cdist_backward(grad.contiguous(), x1, x2, p, result) + x2: _cdist_backward(grad.mT().contiguous(), x2, x1, p, result.mT().contiguous()) + +- name: _cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor + grad: not_implemented("_cdist_backward") + x1: not_implemented("_cdist_backward") + x2: not_implemented("_cdist_backward") + cdist: not_implemented("_cdist_backward") + +- name: normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor + mean: at::zeros_symint(mean.sym_sizes(), grad.options()) + result: auto_element_wise + +- name: normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor + std: at::zeros_symint(std.sym_sizes(), grad.options()) + result: auto_element_wise + +- name: normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor + mean: at::zeros_symint(mean.sym_sizes(), grad.options()) + std: at::zeros_symint(std.sym_sizes(), grad.options()) + result: zeros_like(mean_t) + +- name: linalg_householder_product(Tensor input, Tensor tau) -> Tensor + input, tau: householder_product_backward(grad, result, input, tau) + result: householder_product_jvp(input_t, tau_t, result, input_p, tau_p) + +- name: ormqr(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False) -> Tensor + self, input2, input3: ormqr_backward(grad, result, self, input2, input3, left, transpose, grad_input_mask) + +- name: permute(Tensor(a) self, int[] dims) -> Tensor(a) + self: permute_backwards(grad, dims) + result: auto_linear + +- name: poisson(Tensor self, Generator? generator=None) -> Tensor + self: zeros_like(self) + result: auto_element_wise + +- name: pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor + self: pow_backward(grad, self, exponent) + result: auto_element_wise + +- name: pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor + self: pow_backward_self(grad, self, exponent) + exponent: pow_backward_exponent(grad, self, exponent, result) + result: (pow_backward_self(self_t.conj(), self_p, exponent_p) + pow_backward_exponent(exponent_t.conj(), self_p, exponent_p, result)).conj() + +- name: pow.Scalar(Scalar self, Tensor exponent) -> Tensor + exponent: pow_backward_exponent(grad, self, exponent, result) + result: auto_element_wise + +- name: prod(Tensor self, *, ScalarType? dtype=None) -> Tensor + self: prod_backward(grad, self.to(grad.scalar_type()), result) + result: (prod_backward(at::ones({}, result.options()).expand_as(result), self_p.to(result.scalar_type()), result) * self_t.conj()).sum().conj() + +- name: prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + self: prod_backward(grad, self.to(grad.scalar_type()), result, dim, keepdim) + result: (prod_backward(at::ones({}, result.options()).expand_as(result), self_p.to(result.scalar_type()), result, dim, keepdim) * self_t.conj()).sum(dim, keepdim).conj() + +- name: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor + self: "accumulate ? grad : grad.put(index, zeros_like(source), false)" + index: non_differentiable + source: grad.take(index).reshape_as(source) + result: self_t.put(index, source_t, accumulate) + +- name: linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R) + A: linalg_qr_backward(grad_Q, grad_R, Q, R, mode) + Q, R: linalg_qr_jvp(A_t, Q, R, mode) + +- name: rad2deg(Tensor self) -> Tensor + self: rad2deg_backward(grad) + result: auto_element_wise + +- name: random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: reciprocal(Tensor self) -> Tensor + self: -grad * (result * result).conj() + result: auto_element_wise + +- name: remainder.Scalar(Tensor self, Scalar other) -> Tensor + self: grad + result: auto_element_wise + +- name: remainder.Tensor(Tensor self, Tensor other) -> Tensor + self: grad + other: -grad * self.div(other, /*rounding_mode=*/"floor") + result: self_t - other_t * self_p.div(other_p, /*rounding_mode=*/"floor") + +- name: renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor + self: renorm_backward(grad, self, p, dim, maxnorm) + result: renorm_jvp(self_p, self_t, p, dim, maxnorm) + +- name: repeat(Tensor self, SymInt[] repeats) -> Tensor + self: repeat_backward(grad, repeats, self.sym_sizes()) + result: auto_linear + +- name: special_entr(Tensor self) -> Tensor + self: grad * (-(1 + self.log())) + result: auto_element_wise + +- name: special_ndtri(Tensor self) -> Tensor + self: grad * std::sqrt(2 * M_PI) * (result.square() / 2).exp() + result: auto_element_wise + +- name: special_log_ndtr(Tensor self) -> Tensor + self: grad / std::sqrt(2 * M_PI) * (result + self.pow(2) / 2).neg().exp() + result: auto_element_wise + +# [Note: Sometimes view derivatives] +# The following situation applies to other operations as well. +# TODO: This note is only referenced by to_dense and to_sparse*. Make +# this more generic if it's been referenced more than once. +# +# DO NOT define a backward for reshape! +# reshape is special in that it sometimes returns a view, and sometimes not. +# Defining a backward will make codegen spit out the forward call as +# as_variable(baseType->reshape(self)), +# making it impossible (hard) to detect when it is actually a view. +# - name: reshape(Tensor self, IntArrayRef shape) + +- name: _reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a) + self: grad.reshape_symint(self.sym_sizes()) + result: auto_linear + +- name: round(Tensor self) -> Tensor + self: zeros_like(grad) + result: auto_element_wise + +- name: round.decimals(Tensor self, *, int decimals) -> Tensor + self: zeros_like(grad) + result: auto_element_wise + +- name: rsqrt(Tensor self) -> Tensor + self: -0.5 * grad * result.pow(3).conj() + result: auto_element_wise + +- name: scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor + self: grad.scatter(dim, index, 0) + index: non_differentiable + src: grad.gather(dim, index) + result: self_t.scatter(dim, index, src_t) + +- name: scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor + self: grad.scatter(dim, index, 0) + index: non_differentiable + result: self_t.scatter(dim, index, 0) + +- name: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor + self: grad + index: non_differentiable + src: grad.gather(dim, index) + result: scatter_add(self_t, dim, index, src_t) + +- name: select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a) + dispatch: + Default: + self: select_backward_symint(grad, self.sym_sizes(), dim, index) + result: auto_linear + AutogradNestedTensor: + self: _nested_select_backward_symint(grad, self, dim, index) + +- name: select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index) -> Tensor + grad_output: grad.select_symint(dim, index) + result: auto_linear + +- name: sigmoid(Tensor self) -> Tensor + self: sigmoid_backward(grad, result) + result: auto_element_wise + +- name: logit(Tensor self, float? eps=None) -> Tensor + self: "GradMode::is_enabled() ? infinitely_differentiable_logit_backward(grad, self, eps) : logit_backward(grad, self, eps)" + result: auto_element_wise + +- name: sign(Tensor self) -> Tensor + self: zeros_like(grad) + result: auto_element_wise + +- name: sgn(Tensor self) -> Tensor + self: sgn_backward(self, grad, result) + # Cannot use auto_element_wise here because the Jacobian is *not* Hermitian (in fact, it is symmetric) + # The function is not holomorphic, so there's no reason for its Jacobian to be Hermitian + # auto_element_wise has a name that's a bit deceiving in the complex case + result: sgn_backward(self_p, self_t, result) + +- name: sin(Tensor self) -> Tensor + self: grad * self.cos().conj() + result: auto_element_wise + +- name: sinc(Tensor self) -> Tensor + self: sinc_backward(grad, self) + result: auto_element_wise + +- name: sinh(Tensor self) -> Tensor + self: grad * self.cosh().conj() + result: auto_element_wise + +- name: slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) + self: slice_backward_wrapper(grad, self.sym_sizes(), dim, start, end, step) + result: auto_linear + +- name: slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor + grad_output: grad.slice_symint(dim, start, end, step) + result: auto_linear + +- name: slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) + self: grad.slice_symint(dim, start, end, step) + src: slice_scatter_symint(grad, zeros_like(self), dim, start, end, step) + result: auto_linear + +- name: slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor + self: slice_scatter_symint(grad, zeros_like(src), dim, start, end, step) + src: grad.slice_symint(dim, start, end, step) + result: auto_linear + +- name: select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor + self: select_scatter_symint(grad, zeros_like(src), dim, index) + src: grad.select_symint(dim, index) + result: auto_linear + +- name: diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor + self: diagonal_scatter(grad, zeros_like(src), offset, dim1, dim2) + src: grad.diagonal(offset, dim1, dim2) + result: auto_linear + +- name: as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor + self: as_strided_scatter_backward(grad, TensorGeometry(self), TensorGeometry(src), size, stride, storage_offset) + # See Note [as_strided_scatter backward support] + src: grad.contiguous().as_strided_symint(size, stride, storage_offset) + result: auto_linear + +- name: _linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor LU, Tensor pivots, Tensor info) + A, B: linalg_solve_backward(grad, result, A, LU, pivots, left, grad_input_mask[1]) + result: "linalg_solve_jvp(A_t, B_t, result, LU, pivots, left, A_p.is_contiguous() && !A_p.is_complex())" + output_differentiability: [True, False, False, False] # LU is an auxiliary tensor not exposed to the user + +- name: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), true) + output_differentiability: [True, False] + values: gather_with_keepdimed_indices(self_t, dim, indices, true) + +- name: sort.stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), true) + output_differentiability: [True, False] + values: gather_with_keepdimed_indices(self_t, dim, indices, true) + +- name: split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[] + self: split_backward(grads, split_size, dim, self.sym_sizes(), self.options()) + result: auto_linear + +- name: unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[] + self: split_backward(grads, split_size, dim, self.sym_sizes(), self.options()) + result: auto_linear + +- name: split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[] + dispatch: + Default: + self: split_with_sizes_backward(grads, split_sizes, dim, self.sym_sizes(), self.options()) + result: auto_linear + AutogradNestedTensor: + self: _nested_split_with_sizes_backward(grads, split_sizes, dim, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), self.options()) + +- name: unsafe_split_with_sizes(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[] + self: split_with_sizes_backward(grads, split_sizes, dim, self.sym_sizes(), self.options()) + result: auto_linear + +- name: sqrt(Tensor self) -> Tensor + self: grad / (2 * result.conj()) + result: auto_element_wise + +- name: squeeze(Tensor(a) self) -> Tensor(a) + self: unsqueeze_to(grad, self.sym_sizes()) + result: auto_linear + +- name: squeeze.dim(Tensor(a) self, int dim) -> Tensor(a) + dispatch: + Default: + self: unsqueeze_to(grad, dim, self.sym_sizes()) + result: auto_linear + AutogradNestedTensor: + self: grad.unsqueeze(dim) + +- name: squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a) + dispatch: + Default: + self: unsqueeze_to(grad, dim, self.sym_sizes()) + result: auto_linear + AutogradNestedTensor: + self: unsqueeze_multiple(grad, dim, self.dim()) + +- name: squeeze_(Tensor(a!) self) -> Tensor(a!) + self: unsqueeze_to(grad, self.sym_sizes()) + result: auto_linear + +- name: squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!) + self: unsqueeze_to(grad, dim, self.sym_sizes()) + result: auto_linear + +- name: squeeze_.dims(Tensor(a!) self, int[] dim) -> Tensor(a!) + self: unsqueeze_to(grad, dim, self.sym_sizes()) + result: auto_linear + +- name: std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor + self: std_backward(result, grad, self, dim, correction, keepdim) + # pointwise (variance) + sum + sqrt + result: (at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) / (2. * result)).masked_fill_(result == 0, 0) + +- name: std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) + self: std_mean_backward(grads[0], grads[1], self, result0, dim, correction, keepdim) + result0: (at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) / (2. * result0)).masked_fill_(result0 == 0, 0) + # linear + result1: mean(self_t, dim.value_or(IntArrayRef({})), keepdim) + +- name: sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + self: handle_r_to_c(self.scalar_type(), grad) + other: handle_r_to_c(other.scalar_type(), maybe_multiply(-grad, alpha.conj())) + result: self_t - maybe_multiply(other_t, alpha) + +- name: sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + self: handle_r_to_c(self.scalar_type(), grad) + result: auto_element_wise + +- name: rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + self: handle_r_to_c(self.scalar_type(), maybe_multiply(-grad, alpha.conj())) + other: handle_r_to_c(other.scalar_type(), grad) + result: -maybe_multiply(self_t, alpha) + other_t + +- name: rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + self: handle_r_to_c(self.scalar_type(), maybe_multiply(-grad, alpha.conj())) + result: auto_element_wise + +- name: sum(Tensor self, *, ScalarType? dtype=None) -> Tensor + dispatch: + Default: + self: grad.expand_symint(self.sym_sizes()) + result: auto_linear + AutogradNestedTensor: + # TODO: replace this with grad.expand_as(self) when that is supported + self: ones_like(self) * grad + result: auto_linear + +- name: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + dispatch: + Default: + self: sum_backward(grad, self.sym_sizes(), dim, keepdim) + result: auto_linear + AutogradNestedTensor: + # TODO: replace this function once semantics for nested tensor expand have been settled on + self: _nested_sum_backward(grad, self, dim, keepdim) + +- name: nansum(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + self: nansum_backward(grad.to(self.scalar_type()), self, dim, keepdim) + result: at::where(self_p.isnan(), 0, self_t).sum(dim, keepdim, dtype) + +# We never call _linalg_svd with compute_uv=False in an autograd context, so we don't even consider it here +- name: _linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh) + A: "svd_backward(full_matrices && grad_U.defined() ? grad_U.narrow_symint(-1, 0, S.sym_size(-1)) : grad_U, + grad_S, + full_matrices && grad_Vh.defined() ? grad_Vh.narrow_symint(-2, 0, S.sym_size(-1)) : grad_Vh, + full_matrices ? U.narrow_symint(-1, 0, S.sym_size(-1)) : U, + S, + full_matrices ? Vh.narrow_symint(-2, 0, S.sym_size(-1)) : Vh)" + U, S, Vh: linalg_svd_jvp(A_t, U, S, Vh, full_matrices) + +- name: _linalg_eigh(Tensor A, str UPLO="L", bool compute_v=True) -> (Tensor eigenvalues, Tensor eigenvectors) + A: linalg_eig_backward(grads[0], grads[1], eigenvalues, eigenvectors, /*is_hermitian=*/true) + eigenvalues, eigenvectors: linalg_eig_jvp(A_t, eigenvalues, eigenvectors, /*is_hermitian=*/true) + +- name: linalg_eig(Tensor self) -> (Tensor eigenvalues, Tensor eigenvectors) + self: handle_r_to_c(self.scalar_type(), linalg_eig_backward(grads[0], grads[1], eigenvalues, eigenvectors, /*is_hermitian=*/false)) + eigenvalues, eigenvectors: linalg_eig_jvp(self_t, eigenvalues, eigenvectors, /*is_hermitian=*/false) + +- name: t(Tensor(a) self) -> Tensor(a) + self: grad.t() + result: auto_linear + +- name: t_(Tensor(a!) self) -> Tensor(a!) + self: grad.t() + result: auto_linear + +- name: one_hot(Tensor self, int num_classes=-1) -> Tensor + self: non_differentiable + +- name: flip(Tensor self, int[] dims) -> Tensor + self: grad.flip(dims) + result: auto_linear + +- name: roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor + self: grad.roll_symint(fmap(reverse_list_symint(shifts), [](c10::SymInt i){return -i;}), reverse_list(dims)) + result: auto_linear + +- name: rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor + self: grad.rot90(-k, dims) + result: auto_linear + +- name: take(Tensor self, Tensor index) -> Tensor + self: take_backward(grad, self, index) + index: non_differentiable + result: auto_linear + +- name: tan(Tensor self) -> Tensor + self: grad * (1 + result.pow(2)).conj() + result: auto_element_wise + +- name: tanh(Tensor self) -> Tensor + self: tanh_backward(grad, result) + result: auto_element_wise + +- name: topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) + self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), true) + output_differentiability: [True, False] + values: gather(self_t, dim, indices) + +- name: trace(Tensor self) -> Tensor + self: trace_backward_symint(grad, self.sym_sizes()) + result: auto_linear + +- name: transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a) + self: grad.transpose(dim0, dim1) + result: auto_linear + +- name: transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) + self: grad.transpose(dim0, dim1) + result: auto_linear + +- name: triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor solution, Tensor cloned_coefficient) + self, A: triangular_solve_backward(grad_solution, grad_cloned_coefficient, self, A, solution, upper, transpose, unitriangular, grad_input_mask) + solution: triangular_solve_jvp(solution, A_p, A_t, self_t, upper, transpose, unitriangular) + cloned_coefficient: A_t + +- name: linalg_solve_triangular(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False) -> Tensor + self, B: linalg_solve_triangular_backward(grad, self, result, upper, left, unitriangular, grad_input_mask) + result: linalg_solve_triangular_forward_AD(self_t, B_t, self_p, result, upper, left, unitriangular) + +- name: tril(Tensor self, SymInt diagonal=0) -> Tensor + self: grad.tril_symint(diagonal) + result: auto_linear + +- name: triu(Tensor self, SymInt diagonal=0) -> Tensor + self: grad.triu_symint(diagonal) + result: auto_linear + +- name: trunc(Tensor self) -> Tensor + self: zeros_like(grad) + result: auto_element_wise + +- name: hash_tensor(Tensor self, int[1] dim=[], *, bool keepdim=False, int mode=0) -> Tensor + output_differentiability: [False] + +# DO NOT define a backward for to_dense +# See [Note: Sometimes view derivatives] +# - name: to_dense(Tensor self, ScalarType? dtype=None, *, bool? masked_grad=None) -> Tensor +# +- name: _to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor + self: to_dense_backward(grad, self, masked_grad) + +# DO NOT define a backward for to_sparse.sparse_dim +# See [Note: Sometimes view derivatives] +# - name: to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor +# +- name: _to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor + self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) + +# DO NOT define a backward for to_sparse +# See [Note: Sometimes view derivatives] +# - name: to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor +# +- name: _to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor + self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) + +# DO NOT define a backward for to_sparse_csr +# See [Note: Sometimes view derivatives] +# - name: to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor +# +- name: _to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor + self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) + +# DO NOT define a backward for to_sparse_csc +# See [Note: Sometimes view derivatives] +# - name: to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor +# +- name: _to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor + self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) + +# DO NOT define a backward for to_sparse_bsr +# See [Note: Sometimes view derivatives] +# - name: to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor +# +- name: _to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor + self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) + +# DO NOT define a backward for to_sparse_bsc +# See [Note: Sometimes view derivatives] +# - name: to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor +# +- name: _to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor + self: to_sparse_backward(grad, self.layout(), self.sym_blocksize()) + +- name: to_mkldnn(Tensor self, ScalarType? dtype=None) -> Tensor + self: to_mkldnn_backward(grad, self) + +- name: unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a) + self: unfold_backward_symint(grad, self.sym_sizes(), dimension, size, step) + result: auto_linear + +- name: unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor + grad_in: grad.unfold(dim, size, step) + result: auto_linear + +- name: uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!) + self: zeros_like(grad) + result: self_t.zero_() + +- name: _unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor) + output_differentiability: [True, False] + self: not_implemented("_unique") + +- name: unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + output_differentiability: [True, False, False] + self: not_implemented("unique_dim") + +- name: unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor) + output_differentiability: [True, False, False] + self: not_implemented("unique_consecutive") + +- name: unique_dim_consecutive(Tensor self, int dim, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + output_differentiability: [True, False, False] + self: not_implemented("unique_dim_consecutive") + +- name: _unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + output_differentiability: [True, False, False] + self: not_implemented("_unique2") + +- name: _unsafe_view(Tensor self, SymInt[] size) -> Tensor + self: grad.reshape_symint(self.sym_sizes()) + result: auto_linear + +- name: lift(Tensor self) -> Tensor + self: grad + result: auto_linear + +- name: lift_fresh(Tensor(a) self) -> Tensor(a) + self: grad + result: auto_linear + +- name: unsqueeze(Tensor(a) self, int dim) -> Tensor(a) + self: grad.squeeze(dim) + result: auto_linear + +- name: unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!) + self: grad.squeeze(dim) + result: auto_linear + +- name: var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor + self: var_backward(grad, self, dim, correction, keepdim) + # pointwise + sum + result: at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) + +- name: var_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) + self: var_mean_backward(grads[0], grads[1], self, dim, correction, keepdim) + result0: at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) + # linear + result1: mean(self_t, dim.value_or(IntArrayRef({})), keepdim) + +- name: view(Tensor(a) self, SymInt[] size) -> Tensor(a) + dispatch: + Default: + self: grad.reshape_symint(self.sym_sizes()) + result: auto_linear + AutogradNestedTensor: + self: grad.reshape_as(self) + result: auto_linear + +- name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a) + output_differentiability: [False] + +- name: view_as_real(Tensor(a) self) -> Tensor(a) + self: at::view_as_complex(grad.contiguous()) # gx0 + 1j * gx1 + result: at::view_as_real(self_t) + +- name: view_as_complex(Tensor(a) self) -> Tensor(a) + self: at::view_as_real(grad.contiguous().resolve_conj()) # [gx, gy] + result: at::view_as_complex(self_t) + +- name: where.self(Tensor condition, Tensor self, Tensor other) -> Tensor + condition: non_differentiable + self: where(condition, grad, 0) + other: where(condition, 0, grad) + result: where(condition, self_t, other_t) + +# weight_norm_cuda_interface_backward does not have an explicitly defined derivative, so if we do happen +# to be running backward with create_graph=True, fall back to a backward function that uses +# differentiable ops. +- name: _weight_norm_interface(Tensor v, Tensor g, int dim=0) -> (Tensor, Tensor) + v, g: "grad.defined() ? (GradMode::is_enabled() ? _weight_norm_differentiable_backward(grad.contiguous(), v, g, result1, dim) : _weight_norm_interface_backward(grad.contiguous(), v, g, result1, dim)) : std::tuple()" + +- name: zero_(Tensor(a!) self) -> Tensor(a!) + self: zeros_like(grad) + result: auto_linear + +- name: sparse_mask(Tensor self, Tensor mask) -> Tensor + self: sparse_mask_backward(grad, mask, self.layout()) + mask: non_differentiable + +- name: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor + indices: non_differentiable + values: grad.sparse_mask(result)._values() + +- name: sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + compressed_indices: non_differentiable + plain_indices: non_differentiable + # TODO: remove to_dense after gh-107381 is fixed + values: grad.to_dense().sparse_mask(result).values() + +- name: _sparse_sum.dim(Tensor self, int[1] dim) -> Tensor + self: at::_sparse_sum_backward(grad, self, dim) + +- name: _standard_gamma(Tensor self, Generator? generator=None) -> Tensor + self: grad * _standard_gamma_grad(self, result) + +- name: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor + self: not_implemented("_standard_gamma_grad") + +- name: values(Tensor(a) self) -> Tensor(a) + dispatch: + Default: + self: values_backward(grad, self) + AutogradNestedTensor: + self: at::_nested_view_from_buffer(grad.contiguous(), self._nested_tensor_size(), self._nested_tensor_strides(), self._nested_tensor_storage_offsets()) + +# Why is _values() not differentiable? +# See NOTE [ Sparse: autograd and API ] +- name: _values(Tensor(a) self) -> Tensor(a) + output_differentiability: [False] + +# NN +- name: _trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor + i1, i2, i3: "_trilinear_backward(grad, + wrap_opt_if(i1, grad_input_mask[1] || grad_input_mask[2]), + wrap_opt_if(i2, grad_input_mask[0] || grad_input_mask[2]), + wrap_opt_if(i3, grad_input_mask[0] || grad_input_mask[1]), + expand1, expand2, expand3, sumdim, grad_input_mask)" + result: "_trilinear(i1_t, i2_p, i3_p, expand1, expand2, expand3, sumdim, unroll_dim) + + _trilinear(i1_p, i2_t, i3_p, expand1, expand2, expand3, sumdim, unroll_dim) + + _trilinear(i1_p, i2_p, i3_t, expand1, expand2, expand3, sumdim, unroll_dim)" + +- name: constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor + self: constant_pad_nd_backward(grad, pad) + result: constant_pad_nd_symint(self_t, pad, 0) + +- name: binary_cross_entropy(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor + self: binary_cross_entropy_backward(grad, self, target, weight, reduction) + target: binary_cross_entropy_target_backward(grad, self, target, weight, reduction) + result: "apply_loss_reduction( + binary_cross_entropy_backward(self_t, self_p, target_p, weight, at::Reduction::None) + + binary_cross_entropy_target_backward(target_t, self_p, target_p, weight, at::Reduction::None), + reduction)" + +- name: binary_cross_entropy_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor + self: binary_cross_entropy_double_backward(grad_output, grad, self, target, weight, reduction) + target: binary_cross_entropy_double_backward_target(grad, grad_output, self, target, weight, reduction) + grad_output: binary_cross_entropy_double_backward_grad_output(grad, self, target, weight, reduction) + result: " binary_cross_entropy_double_backward(grad_output_p, self_t, self_p, target_p, weight, reduction) + + binary_cross_entropy_double_backward_target(target_t, grad_output_p, self_p, target_p, weight, reduction) + + binary_cross_entropy_double_backward_grad_output(grad_output_t, self_p, target_p, weight, reduction)" + +- name: binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor + self: binary_cross_entropy_with_logits_backward(grad, self, target, weight, pos_weight, reduction) + target: binary_cross_entropy_with_logits_target_backward(grad, self, target, weight, pos_weight, reduction) + result: "apply_loss_reduction( + binary_cross_entropy_with_logits_backward(self_t, self_p, target_p, weight, pos_weight, at::Reduction::None) + + binary_cross_entropy_with_logits_target_backward(target_t, self_p, target_p, weight, pos_weight, at::Reduction::None), + reduction)" + +- name: embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor + indices: non_differentiable + weight: embedding_backward_symint(grad, indices, weight.sym_size(0), padding_idx, scale_grad_by_freq, sparse) + result: auto_linear + +- name: embedding_dense_backward(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq) -> Tensor + grad_output: embedding_dense_double_backward_symint(grad, indices, padding_idx) + indices: non_differentiable + result: auto_linear + +- name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor) + indices: non_differentiable + offsets: non_differentiable + weight: _embedding_bag_backward_symint(grad, indices, offsets, result1, result2, result3, weight.sym_size(0), scale_grad_by_freq, mode, sparse, per_sample_weights, padding_idx) + per_sample_weights: _embedding_bag_per_sample_weights_backward(grad, weight, indices, offsets, result1, mode, padding_idx) + +- name: _embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor + grad: not_implemented("_embedding_bag_backward") + indices: non_differentiable + offsets: non_differentiable + offset2bag: non_differentiable + bag_size: non_differentiable + maximum_indices: non_differentiable + per_sample_weights: not_implemented("_embedding_bag_backward") + +- name: _embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor + grad: not_implemented("_embedding_bag_dense_backward") + indices: non_differentiable + offset2bag: non_differentiable + bag_size: non_differentiable + maximum_indices: non_differentiable + per_sample_weights: not_implemented("_embedding_bag_dense_backward") + +- name: embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!) + indices: non_differentiable + self: not_implemented("embedding_renorm") + +- name: mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor + self: mse_loss_backward(grad, self, target, reduction) + target: mse_loss_backward(grad, target, self, reduction) + result: apply_loss_reduction(mse_loss_backward(self_t.conj(), self_p, target_p, at::Reduction::None).conj() + mse_loss_backward(target_t.conj(), target_p, self_p, at::Reduction::None).conj(), reduction) + +- name: multi_margin_loss(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean) -> Tensor + self: multi_margin_loss_backward(grad, self, target, p, margin, weight, reduction) + target: non_differentiable + +- name: multilabel_margin_loss_forward(Tensor self, Tensor target, int reduction) -> (Tensor output, Tensor is_target) + self: multilabel_margin_loss_backward(grad, self, target, reduction, is_target) + target: non_differentiable + +- name: nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight) + self: nll_loss_backward_symint(grad, self, target, weight, reduction, ignore_index, total_weight) + target: non_differentiable + output: std::get<0>(nll_loss_forward_symint(self_t, target, weight, reduction, ignore_index)) + +- name: nll_loss2d_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight) + self: nll_loss2d_backward_symint(grad, self, target, weight, reduction, ignore_index, total_weight) + target: non_differentiable + output: std::get<0>(nll_loss2d_forward_symint(self_t, target, weight, reduction, ignore_index)) + +- name: smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean, float beta=1.0) -> Tensor + self: smooth_l1_loss_backward(grad, self, target, reduction, beta) + target: smooth_l1_loss_backward(grad, target, self, reduction, beta) + result: apply_loss_reduction(smooth_l1_loss_backward(self_t.conj(), self_p, target_p, at::Reduction::None, beta).conj() + smooth_l1_loss_backward(target_t.conj(), target_p, self_p, at::Reduction::None, beta).conj(), reduction) + +- name: huber_loss(Tensor self, Tensor target, int reduction=Mean, float delta=1.0) -> Tensor + self: huber_loss_backward(grad, self, target, reduction, delta) + target: huber_loss_backward(grad, target, self, reduction, delta) + result: apply_loss_reduction(huber_loss_backward(self_t.conj(), self_p, target_p, at::Reduction::None, delta).conj() + huber_loss_backward(target_t.conj(), target_p, self_p, at::Reduction::None, delta).conj(), reduction) + +- name: soft_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor + self: soft_margin_loss_backward(grad, self, target, reduction) + result: apply_loss_reduction(soft_margin_loss_backward(self_t.conj(), self_p, target, at::Reduction::None).conj(), reduction) + +- name: relu(Tensor self) -> Tensor + self: threshold_backward(grad, result, 0) + result: auto_element_wise + +- name: silu(Tensor self) -> Tensor + self: "GradMode::is_enabled() ? infinitely_differentiable_silu_backward(grad, self) : silu_backward(grad, self)" + result: auto_element_wise + +- name: mish(Tensor self) -> Tensor + self: "GradMode::is_enabled() ? infinitely_differentiable_mish_backward(grad, self) : mish_backward(grad, self)" + result: auto_element_wise + +- name: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor + self: elu_backward(grad, alpha, scale, input_scale, /* is_result */ false, self) + result: auto_element_wise + +- name: elu_(Tensor(a!) self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor(a!) + self: elu_backward(grad, alpha, scale, input_scale, /* is_result */ true, result) + result: self_t.copy_(elu_backward(original_self_t, alpha, scale, input_scale, /* is_result */ true, result)) + +- name: celu(Tensor self, Scalar alpha=1.0) -> Tensor + self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ false, self) + result: auto_element_wise + +- name: celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!) + self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ true, result) + result: self_t.copy_(elu_backward(original_self_t, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ true, result)) + +- name: gelu(Tensor self, *, str approximate='none') -> Tensor + self: gelu_backward(grad, self, approximate) + result: auto_element_wise + +- name: gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor + grad_output: gelu_backward(grad, self, approximate) + self: gelu_double_backward(grad, grad_output, self, approximate) + result: gelu_backward(grad_output_t, self_p, approximate) + gelu_double_backward(self_t, grad_output_p, self_p, approximate) + +- name: glu(Tensor self, int dim=-1) -> Tensor + # TODO: glu_backward can benefit from forward result, + # and forward ad/forward over reverse ad for that matter + self: glu_backward(grad, self, dim) + result: glu_jvp(result, self_p, self_t, dim) + +- name: hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor + self: hardshrink_backward(grad, self, lambd) + result: auto_element_wise + +- name: hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor + grad_out: hardshrink_backward(grad, self, lambd) + self: zeros_like(grad) + result: at::where((self_p > lambd).logical_or(self_p < -lambd), grad_out_t, at::zeros({}, result.options()).expand_as(result)) + +- name: hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor + self: hardtanh_backward(grad, self, min_val, max_val) + result: auto_element_wise + +- name: leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor + self: leaky_relu_backward(grad, self, negative_slope, false) + result: auto_element_wise + +- name: leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> Tensor(a!) + self: leaky_relu_backward(grad, result, negative_slope, true) + result: self_t.copy_(leaky_relu_backward(original_self_t.conj(), result, negative_slope, true).conj()) + +- name: log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer) + self: log_sigmoid_backward(grad, self, buffer) + output: log_sigmoid_backward(self_t.conj(), self_p, buffer).conj() + output_differentiability: [True, False] + +- name: _log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor + self: _log_softmax_backward_data(grad, result, dim, self.scalar_type()) + result: self_t - logsumexp_jvp(self_p, self_t, {dim}, true) + +- name: _sparse_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor + self: _sparse_log_softmax_backward_data(grad, result, dim, self) + +- name: _masked_softmax(Tensor self, Tensor mask, int? dim=None, int? mask_type=None) -> Tensor + self: _masked_softmax_backward(grad, result, mask, dim) + mask: non_differentiable + +- name: _prelu_kernel(Tensor self, Tensor weight) -> Tensor + self, weight: "grad.defined() ? _prelu_kernel_backward(grad, self, weight) : std::tuple()" + result: at::where(self_p >= 0, self_t, weight_p * self_t + weight_t * self_p) + +- name: _prelu_kernel_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor) + grad_output: "grads[0].defined() ? + (grads[1].defined() ? at::where(self >= 0, grads[0], grads[0] * weight + grads[1] * self) + : at::where(self >= 0, grads[0], grads[0] * weight)) + : at::where(self >= 0, at::zeros({}, grad_output.options()), grads[1] * self)" + self: "grads[1].defined() ? at::where(self >= 0, at::zeros({}, self.options()), grad_output * grads[1]) : zeros_like(self)" + weight: "grads[0].defined() ? at::where(self >= 0, at::zeros({}, weight.options()), grad_output * grads[0]) : zeros_like(self)" + result0: at::where(self_p >= 0, grad_output_t, grad_output_t * weight_p + grad_output_p * weight_t) + result1: at::where(self_p >= 0, at::zeros({}, self_p.options()), grad_output_p * self_t + grad_output_t * self_p) + +- name: rrelu_with_noise(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor + self: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false) + result: auto_element_wise + +- name: rrelu_with_noise_(Tensor(a!) self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!) + self: rrelu_with_noise_backward(grad, result, noise, lower, upper, training, true) + +- name: rrelu_with_noise_functional(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> (Tensor, Tensor noise_out) + noise: non_differentiable + self: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false) + +- name: _softmax(Tensor self, int dim, bool half_to_float) -> Tensor + self: _softmax_backward_data(grad, result, dim, self.scalar_type()) + result: result * (self_t - logsumexp_jvp(self_p, self_t, {dim}, true)) + +- name: _sparse_softmax(Tensor self, int dim, bool half_to_float) -> Tensor + self: _sparse_softmax_backward_data(grad, result, dim, self) + +- name: _sparse_sparse_matmul(Tensor self, Tensor other) -> Tensor + self: sparse_sparse_matmul_backward(grad, self, other, 0) + other: sparse_sparse_matmul_backward(grad, self, other, 1) + +- name: softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor + self: softplus_backward(grad, self, beta, threshold) + result: auto_element_wise + +- name: softshrink(Tensor self, Scalar lambd=0.5) -> Tensor + self: softshrink_backward(grad, self, lambd) + result: auto_element_wise + +- name: threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor + self: threshold_backward(grad, self, threshold) + result: auto_element_wise + +- name: threshold_(Tensor(a!) self, Scalar threshold, Scalar value) -> Tensor(a!) + self: threshold_backward(grad, self, threshold) + result: self_t.copy_(threshold_backward(self_t.conj(), original_self_p, threshold).conj()) + +- name: reflection_pad1d(Tensor self, SymInt[2] padding) -> Tensor + self: reflection_pad1d_backward_symint(grad, self, padding) + result: auto_linear + +- name: reflection_pad2d(Tensor self, SymInt[4] padding) -> Tensor + self: reflection_pad2d_backward_symint(grad, self, padding) + result: auto_linear + +- name: reflection_pad3d(Tensor self, SymInt[6] padding) -> Tensor + self: reflection_pad3d_backward_symint(grad, self, padding) + result: auto_linear + +- name: replication_pad1d(Tensor self, SymInt[2] padding) -> Tensor + self: replication_pad1d_backward_symint(grad, self, padding) + result: auto_linear + +- name: replication_pad2d(Tensor self, SymInt[4] padding) -> Tensor + self: replication_pad2d_backward_symint(grad, self, padding) + result: auto_linear + +- name: replication_pad3d(Tensor self, SymInt[6] padding) -> Tensor + self: replication_pad3d_backward_symint(grad, self, padding) + result: auto_linear + +- name: upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor + self: upsample_linear1d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales) + result: auto_linear + +- name: upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + self: upsample_bilinear2d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w) + result: auto_linear + +- name: _upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + self: _upsample_bilinear2d_aa_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w) + result: auto_linear + +- name: upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + self: upsample_bicubic2d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w) + result: auto_linear + +- name: _upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + self: _upsample_bicubic2d_aa_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w) + result: auto_linear + +- name: upsample_trilinear3d(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + self: upsample_trilinear3d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_d, scales_h, scales_w) + result: auto_linear + +- name: upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor + self: upsample_nearest1d_backward_symint(grad, output_size, self.sym_sizes(), scales) + result: auto_linear + +- name: _upsample_nearest_exact1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor + self: _upsample_nearest_exact1d_backward_symint(grad, output_size, self.sym_sizes(), scales) + result: auto_linear + +- name: upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor + self: upsample_nearest2d_backward_symint(grad, output_size, self.sym_sizes(), scales_h, scales_w) + result: auto_linear + +- name: _upsample_nearest_exact2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor + self: _upsample_nearest_exact2d_backward_symint(grad, output_size, self.sym_sizes(), scales_h, scales_w) + result: auto_linear + +- name: upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + self: upsample_nearest3d_backward_symint(grad, output_size, self.sym_sizes(), scales_d, scales_h, scales_w) + result: auto_linear + +- name: _upsample_nearest_exact3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + self: _upsample_nearest_exact3d_backward_symint(grad, output_size, self.sym_sizes(), scales_d, scales_h, scales_w) + result: auto_linear + +- name: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor + self: pixel_unshuffle(grad, upscale_factor) + result: auto_linear + +- name: pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor + self: pixel_shuffle(grad, downscale_factor) + result: auto_linear + +- name: channel_shuffle(Tensor self, SymInt groups) -> Tensor + self: channel_shuffle_symint(grad, grad.sym_size(1) / groups) + result: auto_linear + +- name: _adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor + self: _adaptive_avg_pool2d_backward(grad, self) + result: auto_linear + +- name: _adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor + self: _adaptive_avg_pool3d_backward(grad, self) + result: auto_linear + +- name: adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor) + self: adaptive_max_pool2d_backward(grad, self, result1) + result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1) + output_differentiability: [True, False] + +- name: adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor) + self: adaptive_max_pool3d_backward(grad, self, result1) + result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1) + output_differentiability: [True, False] + +- name: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor + self: avg_pool2d_backward(grad, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + result: auto_linear + +- name: avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor + self: avg_pool3d_backward(grad, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + result: auto_linear + +- name: fractional_max_pool2d(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples) -> (Tensor, Tensor) + self: fractional_max_pool2d_backward(grad, self, kernel_size, output_size, result1) + result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1) + output_differentiability: [True, False] + +- name: fractional_max_pool3d(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples) -> (Tensor, Tensor) + self: fractional_max_pool3d_backward(grad, self, kernel_size, output_size, result1) + result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1) + output_differentiability: [True, False] + +- name: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor + input, weight, bias: "grad.defined() ? linear_backward(input, grad, weight, grad_input_mask) : std::tuple()" + +- name: linear_backward(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + self, grad_output, weight: linear_double_backward(grads, self, grad_output, weight) + +#mps +- name: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + self: max_pool2d_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode) + +- name: _mps_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor + self, weight, bias: "grad.defined() ? mps_convolution_backward_symint(self, grad, weight, padding, stride, dilation, groups, grad_input_mask) : std::tuple()" + +- name: mps_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + grad_output, self, weight: _convolution_double_backward_symint(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector(padding.size(), 0), groups, grad_input_mask) + +- name: max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) + self: max_pool2d_with_indices_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, result1) + result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1) + output_differentiability: [True, False] + +- name: max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) + self: max_pool3d_with_indices_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, result1) + result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1) + output_differentiability: [True, False] + +- name: max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor + self: max_pool_double_backward(grad, indices, 2) + indices: non_differentiable + result: auto_linear + +- name: max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor + self: max_pool_double_backward(grad, indices, 3) + indices: non_differentiable + result: auto_linear + +- name: convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor + input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple()" + result: convolution_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, stride, padding, dilation, transposed, output_padding, groups) + +# TorchScript serializes calls to _convolution so this entry is present until that is changed to use convolution. +# Note that the benchmark, deterministic, cudnn_enabled, and allow_tf32 flags are queried from the global context +# by convolution_backward instead of being passed along from the forward pass. +- name: _convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor + input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple()" + result: _convolution_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32) + +- name: convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + grad_output, input, weight: _convolution_double_backward_symint(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) + result0: std::get<0>(convolution_backward_symint(grad_output_p, input_p, weight_t, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {true, false, false})) + std::get<0>(convolution_backward_symint(grad_output_t, input_p, weight_p, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {true, false, false})) + result1: std::get<1>(convolution_backward_symint(grad_output_p, input_t, weight_p, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {false, true, false})) + std::get<1>(convolution_backward_symint(grad_output_t, input_p, weight_p, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {false, true, false})) + result2: convolution_backward_jvp_grad_bias(grad_output_t, result2) + +- name: convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor + input, weight, bias: "grad.defined() ? convolution_backward_overrideable_symint(grad, input, weight, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple()" + +- name: convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) + grad_output, input, weight: _convolution_double_backward_symint(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) + +- name: slow_conv_transpose2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple()" + +- name: slow_conv_transpose3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple()" + +- name: _slow_conv2d_forward(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding) -> Tensor + self, weight, bias: "grad.defined() ? _slow_conv2d_backward_symint(grad, self, weight, kernel_size, stride, padding, grad_input_mask) : std::tuple()" + +- name: _slow_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) + grad_output, self, weight: _convolution_double_backward_symint(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1}}, false, {{0, 0}}, 1, grad_input_mask) + +- name: _conv_depthwise2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad.contiguous(), self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ {{0, 0}}, /*groups=*/ 1, grad_input_mask) : std::tuple()" + +- name: conv_depthwise3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad.contiguous(), self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ {{0, 0, 0}}, /*groups=*/ 1, grad_input_mask) : std::tuple()" + +- name: slow_conv3d_forward(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, /*dilation=*/ {{1, 1, 1}}, false, /*output_padding=*/ {{0, 0, 0}}, 1, grad_input_mask) : std::tuple()" + +- name: slow_conv_dilated2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), 1, grad_input_mask) : std::tuple()" + +- name: slow_conv_dilated3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), 1, grad_input_mask) : std::tuple()" + +- name: col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor + self: im2col(grad, kernel_size, dilation, padding, stride) + result: auto_linear + +- name: im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor + self: col2im_symint(grad, {self.sym_size(-2), self.sym_size(-1)}, kernel_size, dilation, padding, stride) + result: auto_linear + +- name: _adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor + grad_output: _adaptive_avg_pool2d_symint(grad, {grad_output.sym_size(-2), grad_output.sym_size(-1)}) + self: zeros_like(self) + result: _adaptive_avg_pool2d_backward(grad_output_t, self_p) + +- name: _adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self) -> Tensor + grad_output: _adaptive_avg_pool3d_symint(grad, { grad_output.sym_size(-3), grad_output.sym_size(-2), grad_output.sym_size(-1) }) + self: zeros_like(self) + result: _adaptive_avg_pool3d_backward(grad_output_t, self_p) + +- name: adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor + grad_output: max_pool_double_backward(grad, indices, 2) + self: zeros_like(self) + result: auto_linear + +- name: adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor + grad_output: max_pool_double_backward(grad, indices, 3) + self: zeros_like(self) + result: auto_linear + +- name: avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor + grad_output: avg_pool2d(grad, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + self: zeros_like(self) + result: avg_pool2d_backward(grad_output_t, self_p, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + +- name: avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor + grad_output: avg_pool3d(grad, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + self: zeros_like(self) + result: avg_pool3d_backward(grad_output_t, self_p, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + +- name: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor + grad_output: elu_backward(grad, alpha, scale, input_scale, is_result, self_or_result) + self_or_result: elu_double_backward(grad, grad_output, alpha, scale, input_scale, is_result, self_or_result) + result: elu_backward(grad_output_t, alpha, scale, input_scale, is_result, self_or_result_p) + elu_double_backward(self_or_result_t, grad_output_p, alpha, scale, input_scale, is_result, self_or_result_p) + +- name: fractional_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices) -> Tensor + grad_output: max_pool_double_backward(grad, indices, 2) + self: zeros_like(self) + result: auto_linear + +- name: fractional_max_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices) -> Tensor + grad_output: max_pool_double_backward(grad, indices, 3) + self: zeros_like(self) + result: auto_linear + +- name: glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor + grad_output: glu_double_backward_grad_output(grad, self, dim) + self: glu_double_backward(grad, grad_output, self, dim) + result: glu_backward_jvp(result, grad_output_p, self_p, grad_output_t, self_t, dim) + +- name: hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor + grad_output: hardtanh_backward(grad, self, min_val, max_val) + self: zeros_like(grad) + result: at::where((self_p > min_val).logical_and(self_p < max_val), grad_output_t, at::zeros({}, result.options()).expand_as(result)) + +- name: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor + grad_output: log_sigmoid_backward(grad, self, buffer) + self: log_sigmoid_double_backward(grad * grad_output, self) + result: log_sigmoid_backward(grad_output_t, self_p, buffer) + log_sigmoid_double_backward(self_t * grad_output_p, self_p) + +- name: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor + grad_output: grad.to(output.dtype()) - (grad.to(output.dtype()) * output.exp()).sum(dim, true) + output: (-grad_output.sum(dim, true) * output.exp() * grad.to(output.dtype())).to(output.dtype()) + +- name: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor + # self_is_result is always false here since double backward call is an out-of-place call, self is input itself + grad_output: leaky_relu_backward(grad, self, negative_slope, false) + self: zeros_like(grad) + # leaky_relu_backward(grad_output, self, negative_slope, false) + # computes grad_output * at::where(self_p > 0, 1, negative_slope) + # so the jvp formula is the following: + # grad_output_t * at::where(self_p > 0, self_p.new_ones([]), negative_slope); + # + # leaky_relu_backward(grad_output, result, negative_slope, true) + # computes grad_output * at::where(result > 0, 1, negative_slope) + # under the assumption that `negative_slope` is positive (otherwise, + # it is not possible to compute the gradient). + # + # so the jvp formula is the following: + # grad_output_t * at::where(result_p > 0, result_p.new_ones([]), negative_slope); + # with the assumption that negative_slope is positive. + # + # Combined together that results in the following optimized kernel which + # also checks the assumption that negative_slope is positive when self_is_result + # is True: + result: leaky_relu_backward(grad_output_t, self_p, negative_slope, self_is_result) + +# This derivative is mps-only, and `error_for_max_pool2d_double_backward` just raises an error. +- name: max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + grad_output: error_for_max_pool2d_double_backward() + self: zeros_like(self) + result: auto_linear + +- name: max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices) -> Tensor + grad_output: max_pool_double_backward(grad, indices, 2) + self: zeros_like(self) + indices: non_differentiable + result: auto_linear + +- name: max_pool3d_with_indices_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices) -> Tensor + grad_output: max_pool_double_backward(grad, indices, 3) + self: zeros_like(self) + indices: non_differentiable + result: auto_linear + +- name: mse_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor + grad_output: mse_loss_backward(grad, self, target, reduction) + self: mse_loss_double_backward(grad * grad_output, self, reduction) + target: -mse_loss_double_backward(grad * grad_output, target, reduction) + result: " mse_loss_double_backward(self_t * grad_output_p, self_p, reduction) + - mse_loss_double_backward(target_t * grad_output_p, target_p, reduction) + + mse_loss_backward(grad_output_t, self_p, target_p, reduction) + " + +- name: nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor + grad_output: nll_loss_symint(grad, target, weight, reduction, ignore_index) + self: zeros_like(grad) + target: non_differentiable + +- name: nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor + grad_output: nll_loss2d_symint(grad, target, weight, reduction, ignore_index) + self: zeros_like(grad) + target: non_differentiable + +- name: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor + # self_is_result is always false here since double backward call is an out-of-place call, self is input itself + grad_output: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false) + self: zeros_like(grad) + result: rrelu_with_noise_backward(grad_output_t, self_p, noise, lower, upper, training, false) + +- name: reflection_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor + grad_output: reflection_pad1d_symint(grad, padding) + self: zeros_like(self) + result: reflection_pad1d_backward_symint(grad_output_t, self_p, padding) + +- name: reflection_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor + grad_output: reflection_pad2d_symint(grad, padding) + self: zeros_like(self) + result: reflection_pad2d_backward_symint(grad_output_t, self_p, padding) + +- name: reflection_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor + grad_output: reflection_pad3d_symint(grad, padding) + self: zeros_like(self) + result: reflection_pad3d_backward_symint(grad_output_t, self_p, padding) + +- name: replication_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor + grad_output: replication_pad1d_symint(grad, padding) + self: zeros_like(self) + result: replication_pad1d_backward_symint(grad_output_t, self_p, padding) + +- name: replication_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor + grad_output: replication_pad2d_symint(grad, padding) + self: zeros_like(self) + result: replication_pad2d_backward_symint(grad_output_t, self_p, padding) + +- name: replication_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor + grad_output: replication_pad3d_symint(grad, padding) + self: zeros_like(self) + result: replication_pad3d_backward_symint(grad_output_t, self_p, padding) + +- name: sparse_sampled_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + self, mat1, mat2: "sparse_sampled_addmm_backward(grad, + self, + wrap_opt_if(mat1, grad_input_mask[2]), + wrap_opt_if(mat2, grad_input_mask[1]), + alpha, beta, grad_input_mask)" + +- name: _sparse_mm_reduce_impl(Tensor self, Tensor other, str reduce) -> (Tensor, Tensor) + output_differentiability: [True, False] + self, other: "grad.defined() ? _sparse_mm_reduce_impl_backward(self, grad, other, reduce, result1, grad_input_mask) : std::tuple()" + +- name: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta) -> Tensor + grad_output: smooth_l1_loss_backward(grad, self, target, reduction, beta) + self: smooth_l1_loss_double_backward(grad * grad_output, self, target, reduction, beta) + target: -smooth_l1_loss_double_backward(grad * grad_output, self, target, reduction, beta) + result: " smooth_l1_loss_double_backward(self_t * grad_output_p, self_p, target_p, reduction, beta) + - smooth_l1_loss_double_backward(target_t * grad_output_p, self_p, target_p, reduction, beta) + + smooth_l1_loss_backward(grad_output_t, self_p, target_p, reduction, beta) + " + +- name: huber_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta) -> Tensor + grad_output: huber_loss_double_backward_grad_output(grad, grad_output, self, target, reduction, delta) + self: huber_loss_double_backward(grad * grad_output, self, target, reduction, delta) + target: -huber_loss_double_backward(grad * grad_output, self, target, reduction, delta) + +- name: softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold) -> Tensor + grad_output: softplus_backward(grad, self, beta, threshold) + self: softplus_double_backward(grad * grad_output, self, beta, threshold) + result: "softplus_backward(grad_output_t, self_p, beta, threshold) + + softplus_double_backward(self_t * grad_output_p, self_p, beta, threshold)" + +- name: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor + grad_output: _softmax_backward_data(grad.to(output.dtype()), output, dim, input_dtype) + output: softmax_double_backward(grad.to(output.dtype()), grad_output, dim, output).to(output.dtype()) + +- name: soft_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor + grad_output: soft_margin_loss_double_backward_grad_output(grad, grad_output, self, target, reduction) + self: soft_margin_loss_double_backward(grad * grad_output, self, target, reduction) + +- name: softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) -> Tensor + grad_output: softshrink_backward(grad, self, lambd) + self: zeros_like(grad) + result: at::where((self_p > lambd).logical_or(self_p < -lambd), grad_output_t, at::zeros({}, result.options()).expand_as(result)) + +- name: threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor + grad_output: threshold_backward(grad, self, threshold) + self: zeros_like(grad) + result: zeros_like(self_t) + threshold_backward(grad_output_t, self_p, threshold) + +- name: upsample_linear1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None) -> Tensor + grad_output: upsample_linear1d_symint(grad, output_size, align_corners, scales) + result: auto_linear + +- name: upsample_bilinear2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: upsample_bilinear2d_symint(grad, output_size, align_corners, scales_h, scales_w) + result: auto_linear + +- name: _upsample_bilinear2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: _upsample_bilinear2d_aa_symint(grad, output_size, align_corners, scales_h, scales_w) + result: auto_linear + +- name: upsample_bicubic2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: upsample_bicubic2d_symint(grad, output_size, align_corners, scales_h, scales_w) + result: auto_linear + +- name: _upsample_bicubic2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: _upsample_bicubic2d_aa_symint(grad, output_size, align_corners, scales_h, scales_w) + result: auto_linear + +- name: upsample_trilinear3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: upsample_trilinear3d_symint(grad, output_size, align_corners, scales_d, scales_h, scales_w) + result: auto_linear + +- name: upsample_nearest1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor + grad_output: upsample_nearest1d_symint(grad, output_size, scales) + result: auto_linear + +- name: _upsample_nearest_exact1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor + grad_output: _upsample_nearest_exact1d_symint(grad, output_size, scales) + result: auto_linear + +- name: upsample_nearest2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: upsample_nearest2d_symint(grad, output_size, scales_h, scales_w) + result: auto_linear + +- name: _upsample_nearest_exact2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: _upsample_nearest_exact2d_symint(grad, output_size, scales_h, scales_w) + result: auto_linear + +- name: upsample_nearest3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: upsample_nearest3d_symint(grad, output_size, scales_d, scales_h, scales_w) + result: auto_linear + +- name: _upsample_nearest_exact3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: _upsample_nearest_exact3d_symint(grad, output_size, scales_d, scales_h, scales_w) + result: auto_linear + +- name: sigmoid_backward(Tensor grad_output, Tensor output) -> Tensor + grad_output: sigmoid_backward(grad, output.conj()) + output: grad.conj() * grad_output * (-2 * output.conj() + 1) + result: sigmoid_backward(grad_output_t, output_p) + output_t.conj() * grad_output_p * (-2 * output_p.conj() + 1) + +- name: tanh_backward(Tensor grad_output, Tensor output) -> Tensor + grad_output: tanh_backward(grad, output.conj()) + output: grad.conj() * (-2 * output.conj() * grad_output) + result: tanh_backward(grad_output_t, output_p) + output_t.conj() * (-2 * output_p.conj() * grad_output_p) + +# cudnn +- name: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor) + log_probs: _cudnn_ctc_loss_backward(grad, result0, result1, zero_infinity) + +- name: _cudnn_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor) + log_probs: _cudnn_ctc_loss_backward(grad, result0, result1, zero_infinity) + +- name: cudnn_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor + self, weight: "_cudnn_convolution_backward(self, grad, weight, padding, output_padding, stride, dilation, true, groups, {grad_input_mask[0], grad_input_mask[1]})" + +- name: _mps_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor + self, weight: "grad.defined() ? mps_convolution_transpose_backward_symint(self, grad, weight, padding, output_padding, stride, dilation, groups, grad_input_mask) : std::tuple()" + +- name: cudnn_convolution(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor + self, weight: "_cudnn_convolution_backward(self, grad, weight, padding, std::vector(padding.size(), 0), stride, dilation, false, groups, {grad_input_mask[0], grad_input_mask[1]})" + +- name: cudnn_grid_sampler(Tensor self, Tensor grid) -> Tensor output + self, grid: "grad.defined() ? cudnn_grid_sampler_backward(self, grid, grad) : std::tuple()" + +- name: cudnn_affine_grid_generator(Tensor theta, int N, int C, int H, int W) -> Tensor grid + theta: cudnn_affine_grid_generator_backward(grad, N, C, H, W) + +# NB: Why is the backwards here so complicated? CuDNN cannot be used to compute +# backward in evaluation mode, because the math for backward in evaluation mode +# is different (since the forward math is different), and CuDNN does not support +# it. And in any case, you shouldn't be using this bn in evaluation mode, +# because it should be merged into the previous convolution (left for future +# work.) +# NB2: The quotes around the gradient are needed to appease YAML parsing rules. +- name: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? (training ? cudnn_batch_norm_backward(input, grad.contiguous(input.suggest_memory_format()), weight, running_mean, running_var, result1, result2, epsilon, retain_variables ? result3.clone() : result3) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon) + +# HACK: save_mean and save_var are going to be passed in as +# requires_grad variables (even though we'll never backprop through +# them) so we need to prevent the unpacking from triggering an error. +- name: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor) + save_mean: not_implemented("cudnn_batch_norm_backward save_mean") + save_var: not_implemented("cudnn_batch_norm_backward save_var") + reserveSpace: not_implemented("cudnn_batch_norm_backward reserveSpace") + input, weight, grad_output: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, true, epsilon, save_mean, save_var, grad_input_mask) + +# nnpack + +- name: _nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1) -> Tensor + # NNPACK does not support strided convolutions in the backwards path, which is the reason why we are using the closest available function that does here. + input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, std::vector(padding.size(), 1), false, std::vector(padding.size(), 0), 1, grad_input_mask) : std::tuple()" + +#LSTM MPS +- name: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) + output_differentiability: [True, True, True, False, False, False] + input, hx, params: "lstm_mps_backward(grads[0], grads[1], grads[2], result3, result4, input, result5, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first)" + +- name: lstm_mps_backward(Tensor? grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[]) + + + +# Only frst three of _cudnn_rnn outputs can have gradients. +# _cudnn_rnn outputs: (output, hy, cy, reserve, weight_buf) +- name: _cudnn_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + dropout_state: non_differentiable + output_differentiability: [True, True, True, False, False] + input, hx, cx, weight: "_cudnn_rnn_backward_symint(input, weight, weight_stride0, result4, hx, cx, result0, grads[0], grads[1], grads[2], mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, retain_variables ? result3.clone() : result3, grad_input_mask)" + +- name: _cudnn_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[]) + dropout_state: non_differentiable + input: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + weight: not_implemented_list("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + hx: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + cx: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + output: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + grad_output: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + grad_hy: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + grad_cy: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg) + +# miopen + +- name: miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, groups, grad_input_mask) : std::tuple()" + +- name: miopen_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()" + +- name: miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()" + +- name: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? (training ? miopen_batch_norm_backward(input, grad.contiguous(input.suggest_memory_format()), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon) + +- name: miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor) + save_mean: not_implemented("miopen_batch_norm_backward save_mean") + save_var: not_implemented("miopen_batch_norm_backward save_var") + input, weight, grad_output: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, true, epsilon, save_mean, save_var, grad_input_mask) + +- name: miopen_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + dropout_state: non_differentiable + output_differentiability: [True, True, True, False, False] + input, hx, cx, weight: "miopen_rnn_backward(input, weight, weight_stride0, result4, hx, cx, result0, grads[0], grads[1], grads[2], mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, retain_variables ? result3.clone() : result3, grad_input_mask)" + +- name: miopen_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[]) + dropout_state: non_differentiable + +- name: mkldnn_rnn_layer(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train) -> (Tensor, Tensor, Tensor, Tensor) + output_differentiability: [True, True, True, False] + input, weight0, weight1, weight2, weight3, hx_, cx_: "GradMode::is_enabled() ? mkldnn_rnn_layer_differentiable_backward(input, weight0, weight1, weight2, weight3, hx_, cx_, result0, result1, result2, grads[0], grads[1], grads[2], reverse, mode, hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes, batch_first, result3) : mkldnn_rnn_layer_backward(input, weight0, weight1, weight2, weight3, hx_, cx_, result0, result1, result2, grads[0], grads[1], grads[2], reverse, mode, hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes, batch_first, result3)" + +- name: mkldnn_rnn_layer_backward(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) + +# mkldnn +- name: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor + self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()" + +- name: mkldnn_linear(Tensor self, Tensor weight, Tensor? bias=None) -> Tensor + self, weight, bias: mkldnn_linear_backward(self, grad, weight, grad_input_mask) + +- name: mkldnn_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + self: mkldnn_max_pool2d_backward(grad, result, self, kernel_size, stride, padding, dilation, ceil_mode) + +- name: mkldnn_max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor + self: mkldnn_max_pool3d_backward(grad, result, self, kernel_size, stride, padding, dilation, ceil_mode) + +- name: mkldnn_adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor + self: mkldnn_adaptive_avg_pool2d_backward(grad, self) + +- name: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor + self: grad.reshape_symint(self.sym_sizes()) + +# NestedTensor +- name: _nested_tensor_from_tensor_list(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + list: "grad.defined()? at::unbind(grad) : std::vector(list.size())" + +- name: _nested_tensor_from_mask(Tensor t, Tensor mask, bool mask_check=True) -> Tensor + t: grad.to_padded_tensor_symint(0, t.sym_sizes()) + mask: non_differentiable + +- name: _nested_from_padded(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False) -> Tensor + padded: _nested_from_padded_backward(grad, padded, fuse_transform_0213) + cpu_nested_shape_example: non_differentiable + +- name: to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor + self: "self.layout() == c10::kJagged ? at::_nested_from_padded_tensor_symint(grad, at::_nested_get_offsets(self), at::_nested_get_jagged_dummy(self), at::_nested_get_ragged_idx(self), at::_nested_get_min_seqlen(self).defined() ? std::optional(at::_nested_get_min_seqlen(self)) : ::std::nullopt, at::_nested_get_max_seqlen(self).defined() ? std::optional(at::_nested_get_max_seqlen(self)) : ::std::nullopt, std::optional(at::_nested_get_values(self).sym_size(0))) : at::_nested_from_padded(grad, self._nested_tensor_size())" + padding: non_differentiable + +- name: _nested_from_padded_tensor(Tensor padded, Tensor offsets, Tensor dummy, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None, SymInt? sum_S=None) -> Tensor + padded: grad.to_padded_tensor_symint(0.0, at::OptionalArrayRef(padded.sym_sizes())) + offsets: non_differentiable + dummy: non_differentiable + +- name: _nested_view_from_buffer(Tensor(a) self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor(a) + self: grad.values() + nested_size: non_differentiable + nested_strides: non_differentiable + +- name: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor(a) + self: grad.values() + offsets: non_differentiable + lengths: non_differentiable + dummy: non_differentiable + min_seqlen: non_differentiable + max_seqlen: non_differentiable + +- name: _nested_get_values(Tensor(a) self) -> Tensor(a) + self: "_nested_view_from_jagged(grad, at::_nested_get_offsets(self), at::_nested_get_jagged_dummy(self), at::_nested_get_lengths(self), at::_nested_get_ragged_idx(self), at::_nested_get_min_seqlen(self).defined() ? std::optional(at::_nested_get_min_seqlen(self)) : ::std::nullopt, at::_nested_get_max_seqlen(self).defined() ? std::optional(at::_nested_get_max_seqlen(self)) : ::std::nullopt)" + +# Transformer +- name: _safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor + self: _softmax_backward_data(grad, result, dim, self.scalar_type()) + result: result * (self_t - safe_logsumexp_jvp(self_p, self_t, {dim}, true)) + +- name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset) + output_differentiability: [True, False, False, False] + query, key, value, attn_bias: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, attn_bias, output, log_sumexp, philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, scale) + +- name: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask) + output_differentiability: [True, False, False, False, False, False, False, False, False] + query, key, value: _scaled_dot_product_flash_attention_backward_symint(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, rng_state, unused, scale) + +- name: _scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor output, Tensor logsumexp) + output_differentiability: [True, False] + query, key, value: _scaled_dot_product_flash_attention_for_cpu_backward(grad, query, key, value, output, logsumexp, dropout_p, is_causal, attn_mask, scale) + +- name: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask) + output_differentiability: [True, False, False, False, False] + query, key, value: _flash_attention_backward_symint(grad, query, key, value, output, softmax_logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, rng_state, unused, scale, window_size_left, window_size_right) + +- name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? seqlen_k=None, int? window_size=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k) + output_differentiability: [True, False, False, False, False, False] + query, key, value, bias: _efficient_attention_backward_symint(grad, query, key, value, bias, output, cu_seqlens_q, cu_seqlens_k, max_seqlen_batch_q, max_seqlen_batch_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias.requires_grad(), scale) + +- name: _cudnn_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + output_differentiability: [True, False, False, False, False, False, False, False, False] + query, key, value: _cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale) + +- name: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + output_differentiability: [True, False, False, False, False, False, False, False, False] + query, key, value: _scaled_dot_product_cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale) + +- name: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + output_differentiability: [True, False, False, False, False, False, False, False, False] + query, key, value, attn_bias: _scaled_dot_product_fused_attention_overrideable_backward_symint(grad, query, key, value, attn_bias, grad_input_mask, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale) + +# fft +- name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor + self: fft_r2c_backward(grad, dim, normalization, onesided, self.sym_size(dim.back())) + result: auto_linear + +- name: _fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor + self: fft_c2r_backward(grad, dim, normalization) + result: auto_linear + +- name: _fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor + self: _fft_c2c_symint(grad, dim, normalization, !forward) + result: auto_linear + +- name: unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[] + dispatch: + Default: + self: unbind_backward(grads, dim) + result: auto_linear + AutogradNestedTensor: + self: "self.layout() == c10::kJagged ? unbind_backward_nested_jagged(grads, self, dim) : unbind_backward_nested(grads, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), dim, self.options())" + result: auto_linear + +- name: stack(Tensor[] tensors, int dim=0) -> Tensor + tensors: stack_tensors_backward(grad, dim, to_args_scalartypes(tensors)) + result: stack_jvp(tensors, dim) + +# fused RNN kernels + +# Only frst two of _thnn_fused_lstm_cell outputs can have gradients. +# _thnn_fused_lstm_cell outputs: (hy, cy, workspace) +- name: _thnn_fused_lstm_cell(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor, Tensor) + output_differentiability: [True, True, False] + input_gates, hidden_gates, cx, input_bias, hidden_bias: "GradMode::is_enabled() ? _thnn_differentiable_lstm_cell_backward(grads[0], grads[1], input_gates, hidden_gates, input_bias, hidden_bias, cx, result1) : _thnn_fused_lstm_cell_backward(grads[0], grads[1], cx, result1, result2, input_bias.defined())" + +- name: _thnn_fused_gru_cell(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor) + input_gates, hidden_gates, hx, input_bias, hidden_bias: "grad.defined() ? (GradMode::is_enabled() ? _thnn_differentiable_gru_cell_backward(grad, input_gates, hidden_gates, hx, input_bias, hidden_bias) : _thnn_fused_gru_cell_backward(grad, result1, input_bias.defined())) : std::tuple()" + +# PackedSequence helpers +- name: _pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor) + input: _pack_padded_sequence_backward_symint(grad, input.sym_sizes(), result1, batch_first) + +# TH wrappers +- name: eq.Scalar(Tensor self, Scalar other) -> Tensor + output_differentiability: [False] + +- name: eq.Tensor(Tensor self, Tensor other) -> Tensor + output_differentiability: [False] + +- name: ge.Scalar(Tensor self, Scalar other) -> Tensor + output_differentiability: [False] + +- name: ge.Tensor(Tensor self, Tensor other) -> Tensor + output_differentiability: [False] + +- name: gt.Scalar(Tensor self, Scalar other) -> Tensor + output_differentiability: [False] + +- name: gt.Tensor(Tensor self, Tensor other) -> Tensor + output_differentiability: [False] + +- name: le.Scalar(Tensor self, Scalar other) -> Tensor + output_differentiability: [False] + +- name: le.Tensor(Tensor self, Tensor other) -> Tensor + output_differentiability: [False] + +- name: lt.Scalar(Tensor self, Scalar other) -> Tensor + output_differentiability: [False] + +- name: lt.Tensor(Tensor self, Tensor other) -> Tensor + output_differentiability: [False] + +- name: ne.Scalar(Tensor self, Scalar other) -> Tensor + output_differentiability: [False] + +- name: ne.Tensor(Tensor self, Tensor other) -> Tensor + output_differentiability: [False] + +- name: multinomial(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor + output_differentiability: [False] + +- name: nonzero(Tensor self) -> Tensor + output_differentiability: [False] + +- name: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor + data: _segment_reduce_backward(grad, result, data, reduce, lengths, offsets, axis, initial) + +- name: _pin_memory(Tensor self, Device? device=None) -> Tensor + self: grad + +- name: _new_zeros_with_same_feature_meta(Tensor self, Tensor other, *, int self_num_batch_dims=0) -> Tensor + self: non_differentiable + other: non_differentiable + output_differentiability: [False] + +- name: _test_warn_in_autograd(Tensor self) -> Tensor + self: warn_backwards(grad) + +- name: _test_autograd_multiple_dispatch.fullcoverage(Tensor self) -> Tensor + dispatch: + Default: + self: grad.expand_symint(self.sym_sizes()) + 1 + result: auto_linear + AutogradNestedTensor: + self: grad.mul(grad) + AutogradCUDA: + self: grad.expand_symint(self.sym_sizes()) * 2 + +- name: _test_autograd_multiple_dispatch.ntonly(Tensor self, bool b) -> Tensor + dispatch: + AutogradNestedTensor: + self: grad.mul(grad).add(grad) + +- name: _test_autograd_multiple_dispatch_view(Tensor(a) self) -> Tensor(a) + dispatch: + Default: + self: grad.reshape_as(self) + AutogradCUDA: + self: grad.reshape_as(self) + 1 + +- name: _efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + output_differentiability: [False] + +- name: scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor + self, src: scatter_reduce_backward(grad, self, dim, index, src, reduce, include_self, result) + index: non_differentiable + result: scatter_reduce_jvp(self_p, self_t, dim, index, src_p, src_t, reduce, include_self, result) + +- name: special_airy_ai(Tensor x) -> Tensor + x: non_differentiable + +- name: special_bessel_j0(Tensor self) -> Tensor + self: non_differentiable + +- name: special_bessel_j1(Tensor self) -> Tensor + self: non_differentiable + +- name: special_bessel_y0(Tensor self) -> Tensor + self: non_differentiable + +- name: special_bessel_y1(Tensor self) -> Tensor + self: non_differentiable + +- name: special_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_hermite_polynomial_h(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_hermite_polynomial_h.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_hermite_polynomial_h.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_hermite_polynomial_he(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_hermite_polynomial_he.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_hermite_polynomial_he.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_laguerre_polynomial_l(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_laguerre_polynomial_l.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_laguerre_polynomial_l.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_legendre_polynomial_p(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_legendre_polynomial_p.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_legendre_polynomial_p.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_modified_bessel_i0(Tensor self) -> Tensor + self: non_differentiable + +- name: special_modified_bessel_i1(Tensor self) -> Tensor + self: non_differentiable + +- name: special_modified_bessel_k0(Tensor self) -> Tensor + self: non_differentiable + +- name: special_modified_bessel_k1(Tensor self) -> Tensor + self: non_differentiable + +- name: special_scaled_modified_bessel_k0(Tensor x) -> Tensor + x: non_differentiable + +- name: special_scaled_modified_bessel_k1(Tensor x) -> Tensor + x: non_differentiable + +- name: special_shifted_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_shifted_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_shifted_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_shifted_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_shifted_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_shifted_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_shifted_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_shifted_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_shifted_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_shifted_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor + x: non_differentiable + n: non_differentiable + +- name: special_shifted_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor + n: non_differentiable + +- name: special_shifted_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor + x: non_differentiable + +- name: special_spherical_bessel_j0(Tensor x) -> Tensor + x: non_differentiable + +- name: _reshape_copy(Tensor self, SymInt[] size) -> Tensor + self: grad.reshape_symint(self.sym_sizes()) + result: auto_linear + +# note(crcrpar): `torchgen/api/autograd` logic would unwantedly replace substrings of `self` and `other` of function names. +- name: _foreach_div.List(Tensor[] self, Tensor[] other) -> Tensor[] + self: div_tensor_self_backward(grads[i], other[i], self[i].scalar_type()) + other: div_tensor_other_backward(grads[i], self[i], other[i]) + result: (self_t - other_t * result[i]) / other_p + +- name: _foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[] + self: pow_backward_self(grads[i], self[i], exponent[i]) + exponent: pow_backward_exponent(grads[i], self[i], exponent[i], result[i]) + result: (pow_backward_self(self_t.conj(), self_p, exponent_p) + pow_backward_exponent(exponent_t.conj(), self_p, exponent_p, result[i])).conj() + +- name: _foreach_pow.ScalarList(Tensor[] self, Scalar[] exponent) -> Tensor[] + self: pow_backward(grads[i], self[i], exponent[i]) + result: pow_backward(self_t.conj(), self_p, exponent[i]).conj() + +- name: _foreach_pow.ScalarAndTensor(Scalar self, Tensor[] exponent) -> Tensor[] + exponent: pow_backward_exponent(grads[i], self, exponent[i], result[i]) + +# note(crcrpar): following definitions seem necessary because the reference native functions +# of `maximum` and `minimum` don't have the overload def with Scalar as their second argument. +- name: _foreach_minimum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + self: at::where(self[i] == scalar, grads[i] / 2, grads[i]).masked_fill_(self[i] > scalar, 0) + result: scalar + at::where(self_p == scalar, at::scalar_tensor(0.5, result[i].options()), (self_p < scalar).to(result[i].scalar_type())) * (self_t - scalar) + +- name: _foreach_minimum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + self: at::where(self[i] == scalars[i], grads[i] / 2, grads[i]).masked_fill_(self[i] > scalars[i], 0) + result: scalars[i] + at::where(self_p == scalars[i], at::scalar_tensor(0.5, result[i].options()), (self_p < scalars[i]).to(result[i].scalar_type())) * (self_t - scalars[i]) + +- name: _foreach_maximum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + self: at::where(self[i] == scalar, grads[i] / 2, grads[i]).masked_fill_(self[i] < scalar, 0) + result: scalar + at::where(self_p == scalar, at::scalar_tensor(0.5, result[i].options()), (self_p > scalar).to(result[i].scalar_type())) * (self_t - scalar) + +- name: _foreach_maximum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + self: at::where(self[i] == scalars[i], grads[i] / 2, grads[i]).masked_fill_(self[i] < scalars[i], 0) + result: scalars[i] + at::where(self_p == scalars[i], at::scalar_tensor(0.5, result[i].options()), (self_p > scalars[i]).to(result[i].scalar_type())) * (self_t - scalars[i]) + +# note(crcrpar): forward-mode AD is tricky for a simple string replace to handle: +# formula.replace("p", "ord") produces `norm_jvord(self_ord, self_t, ord, result)` +- name: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2, ScalarType? dtype=None) -> Tensor[] + self: norm_backward(grads[i], self[i], ord, result[i]) + result: norm_jvp(self_p, self_t, ord, result[i]) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_annotated_fn_args.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_annotated_fn_args.py new file mode 100644 index 0000000000000000000000000000000000000000..2f61209fa6fd0041b732f1400e1162d2f124ad34 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_annotated_fn_args.py @@ -0,0 +1,134 @@ +""" +For procedural tests needed for __torch_function__, we use this function +to export method names and signatures as needed by the tests in +test/test_overrides.py. + +python -m tools.autograd.gen_annotated_fn_args \ + aten/src/ATen/native/native_functions.yaml \ + aten/src/ATen/native/tags.yaml \ + $OUTPUT_DIR \ + tools/autograd + +Where $OUTPUT_DIR is where you would like the files to be +generated. In the full build system, OUTPUT_DIR is +torch/testing/_internal/generated +""" + +from __future__ import annotations + +import argparse +import os +import textwrap +from collections import defaultdict +from typing import Any, TYPE_CHECKING + +import torchgen.api.python as python +from torchgen.context import with_native_function +from torchgen.gen import parse_native_yaml +from torchgen.utils import FileManager + +from .gen_python_functions import ( + is_py_fft_function, + is_py_linalg_function, + is_py_nn_function, + is_py_special_function, + is_py_torch_function, + is_py_variable_method, + should_generate_py_binding, +) + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from torchgen.model import Argument, BaseOperatorName, NativeFunction + + +def gen_annotated( + native_yaml_path: str, tags_yaml_path: str, out: str, autograd_dir: str +) -> None: + native_functions = parse_native_yaml( + native_yaml_path, tags_yaml_path + ).native_functions + mappings = ( + (is_py_torch_function, "torch._C._VariableFunctions"), + (is_py_nn_function, "torch._C._nn"), + (is_py_linalg_function, "torch._C._linalg"), + (is_py_special_function, "torch._C._special"), + (is_py_fft_function, "torch._C._fft"), + (is_py_variable_method, "torch.Tensor"), + ) + annotated_args: list[str] = [] + for pred, namespace in mappings: + groups: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list) + for f in native_functions: + if not should_generate_py_binding(f) or not pred(f): + continue + groups[f.func.name.name].append(f) + for group in groups.values(): + for f in group: + annotated_args.append(f"{namespace}.{gen_annotated_args(f)}") + + template_path = os.path.join(autograd_dir, "templates") + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + fm.write_with_template( + "annotated_fn_args.py", + "annotated_fn_args.py.in", + lambda: { + "annotated_args": textwrap.indent("\n".join(annotated_args), " "), + }, + ) + + +@with_native_function +def gen_annotated_args(f: NativeFunction) -> str: + def _get_kwargs_func_exclusion_list() -> list[str]: + # functions that currently don't work with kwargs in test_overrides.py + return [ + "diagonal", + "round_", + "round", + "scatter_", + ] + + def _add_out_arg( + out_args: list[dict[str, Any]], args: Sequence[Argument], *, is_kwarg_only: bool + ) -> None: + for arg in args: + if arg.default is not None: + continue + out_arg: dict[str, Any] = {} + out_arg["is_kwarg_only"] = str(is_kwarg_only) + out_arg["name"] = arg.name + out_arg["simple_type"] = python.argument_type_str( + arg.type, simple_type=True + ) + size_t = python.argument_type_size(arg.type) + if size_t: + out_arg["size"] = size_t + out_args.append(out_arg) + + out_args: list[dict[str, Any]] = [] + _add_out_arg(out_args, f.func.arguments.flat_positional, is_kwarg_only=False) + if f"{f.func.name.name}" not in _get_kwargs_func_exclusion_list(): + _add_out_arg(out_args, f.func.arguments.flat_kwarg_only, is_kwarg_only=True) + + return f"{f.func.name.name}: {repr(out_args)}," + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate annotated_fn_args script") + parser.add_argument( + "native_functions", metavar="NATIVE", help="path to native_functions.yaml" + ) + parser.add_argument("tags", metavar="TAGS", help="path to tags.yaml") + parser.add_argument("out", metavar="OUT", help="path to output directory") + parser.add_argument( + "autograd", metavar="AUTOGRAD", help="path to template directory" + ) + args = parser.parse_args() + gen_annotated(args.native_functions, args.tags, args.out, args.autograd) + + +if __name__ == "__main__": + main() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_autograd.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_autograd.py new file mode 100644 index 0000000000000000000000000000000000000000..d93d3f4cab4a6f37c0c81c548b4da3b6c5b9dc95 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_autograd.py @@ -0,0 +1,147 @@ +""" +To run this file by hand from the root of the PyTorch +repository, run: + +python -m tools.autograd.gen_autograd \ + aten/src/ATen/native/native_functions.yaml \ + aten/src/ATen/native/tags.yaml \ + $OUTPUT_DIR \ + tools/autograd + +Where $OUTPUT_DIR is where you would like the files to be +generated. In the full build system, OUTPUT_DIR is +torch/csrc/autograd/generated/ +""" + +# gen_autograd.py generates C++ autograd functions and Python bindings. +# +# It delegates to the following scripts: +# +# gen_autograd_functions.py: generates subclasses of torch::autograd::Node +# gen_variable_type.py: generates VariableType.h which contains all tensor methods +# gen_python_functions.py: generates Python bindings to THPVariable +# + +from __future__ import annotations + +import argparse +import os + +from torchgen.api import cpp +from torchgen.api.autograd import ( + match_differentiability_info, + NativeFunctionWithDifferentiabilityInfo, +) +from torchgen.gen import parse_native_yaml +from torchgen.selective_build.selector import SelectiveBuilder + +from . import gen_python_functions +from .gen_autograd_functions import ( + gen_autograd_functions_lib, + gen_autograd_functions_python, +) +from .gen_inplace_or_view_type import gen_inplace_or_view_type +from .gen_trace_type import gen_trace_type +from .gen_variable_factories import gen_variable_factories +from .gen_variable_type import gen_variable_type +from .gen_view_funcs import gen_view_funcs +from .load_derivatives import load_derivatives + + +def gen_autograd( + native_functions_path: str, + tags_path: str, + out: str, + autograd_dir: str, + operator_selector: SelectiveBuilder, + disable_autograd: bool = False, +) -> None: + # Parse and load derivatives.yaml + differentiability_infos, used_dispatch_keys = load_derivatives( + os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path, tags_path + ) + + template_path = os.path.join(autograd_dir, "templates") + + native_funcs = parse_native_yaml(native_functions_path, tags_path).native_functions + fns = sorted( + filter( + operator_selector.is_native_function_selected_for_training, native_funcs + ), + key=lambda f: cpp.name(f.func), + ) + fns_with_diff_infos: list[NativeFunctionWithDifferentiabilityInfo] = ( + match_differentiability_info(fns, differentiability_infos) + ) + + # Generate VariableType.h/cpp + if not disable_autograd: + gen_variable_type( + out, + native_functions_path, + tags_path, + fns_with_diff_infos, + template_path, + used_dispatch_keys, + ) + + gen_inplace_or_view_type( + out, native_functions_path, tags_path, fns_with_diff_infos, template_path + ) + + # operator filter not applied as tracing sources are excluded in selective build + gen_trace_type(out, native_funcs, template_path) + # Generate Functions.h/cpp + gen_autograd_functions_lib(out, differentiability_infos, template_path) + + # Generate variable_factories.h + gen_variable_factories(out, native_functions_path, tags_path, template_path) + + # Generate ViewFuncs.h/cpp + gen_view_funcs(out, fns_with_diff_infos, template_path) + + +def gen_autograd_python( + native_functions_path: str, + tags_path: str, + out: str, + autograd_dir: str, +) -> None: + differentiability_infos, _ = load_derivatives( + os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path, tags_path + ) + + template_path = os.path.join(autograd_dir, "templates") + + # Generate Functions.h/cpp + gen_autograd_functions_python(out, differentiability_infos, template_path) + + # Generate Python bindings + deprecated_path = os.path.join(autograd_dir, "deprecated.yaml") + gen_python_functions.gen( + out, native_functions_path, tags_path, deprecated_path, template_path + ) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate autograd C++ files script") + parser.add_argument( + "native_functions", metavar="NATIVE", help="path to native_functions.yaml" + ) + parser.add_argument("tags", metavar="NATIVE", help="path to tags.yaml") + parser.add_argument("out", metavar="OUT", help="path to output directory") + parser.add_argument( + "autograd", metavar="AUTOGRAD", help="path to autograd directory" + ) + args = parser.parse_args() + gen_autograd( + args.native_functions, + args.tags, + args.out, + args.autograd, + SelectiveBuilder.get_nop_selector(), + ) + + +if __name__ == "__main__": + main() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_autograd_functions.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_autograd_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..d32562374d5f6e85cad18f314fbbf2d3cf415985 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_autograd_functions.py @@ -0,0 +1,1076 @@ +# Generates C++ autograd functions for the derivatives of ATen operations +# +# This writes two files: +# Functions.h/cpp: subclasses of autograd::Node +# python_functions.h/cpp: Python bindings for the above classes +# + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from torchgen.api.autograd import ( + Derivative, + DifferentiabilityInfo, + SavedAttribute, + uses_retain_variables, + uses_single_grad, +) +from torchgen.api.types import ( + ArrayRefCType, + BaseCppType, + BaseCType, + Binding, + boolT, + doubleT, + intArrayRefT, + iTensorListRefT, + ListCType, + longT, + MutRefCType, + OptionalCType, + optionalIntArrayRefT, + optionalSymIntArrayRefT, + scalarT, + stringT, + symIntArrayRefT, + SymIntT, + TENSOR_LIST_LIKE_CTYPES, + tensorListT, + tensorT, + VectorCType, +) +from torchgen.code_template import CodeTemplate +from torchgen.model import Argument, FunctionSchema +from torchgen.utils import FileManager + +from .gen_inplace_or_view_type import VIEW_FUNCTIONS + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +FUNCTION_DECLARATION = CodeTemplate( + """\ +#ifdef _WIN32 +struct ${op} : public ${superclass} { + TORCH_API ${op}() = default; +#else +struct TORCH_API ${op} : public ${superclass} { +#endif + using ${superclass}::${superclass}; + variable_list apply(variable_list&& grads) override; + std::string name() const override { return "${op}"; } + void release_variables() override { + ${thread_lock} + ${release_variables} + } + ${will_release_variables} + void compiled_args(CompiledNodeArgs& args) const override; + variable_list apply_with_saved(const variable_list& inputs, SwapSavedVariables& saved) override; + ${saved_variables} + ${saved_list_sizes} +}; +""" +) + +WILL_RELEASE_VARIABLES = CodeTemplate( + """\ +bool retain_variables = true; +void will_release_variables() override { + retain_variables = false; +} +""" +) + +# We generate e.g. MulBackward0::apply and have that call into +# MulBackward0_apply_functional. The apply_functional is a pure function, +# that is, it does not rely on global state. MulBackward0::apply +# is responsible for querying the autograd engine for which outputs should +# be computed (needs_input_grad), applying locks, +# and unpacking saved variables to pass to MulBackward0_apply_functional. +# +# needs_input_grad is a mapping from input index to if that input needs +# gradients computed. For operators that take in List[Tensor], the List[Tensor] +# is one element in the needs_input_grad that specifies if *any* of the +# List[Tensor] needs input grad. In theory this could be optimized. +FUNCTION_DEFINITION = CodeTemplate( + """\ +static variable_list ${op}_apply_functional( + variable_list&& grads, + std::array needs_input_grad${,apply_functional_args_signature}) +{ + IndexRangeGenerator gen; + ${compute_index_ranges} + variable_list grad_inputs(gen.size()); + ${body} + return grad_inputs; +} +inline variable_list ${op}_apply_functional_ivalue(const variable_list& grads, const ivalue_list& args) +{ +#ifdef C10_MOBILE + TORCH_INTERNAL_ASSERT(false, "compiled autograd doesn't work on mobile"); +#else + auto packed_args = PackedArgs(args); + auto needs_input_grad = packed_args.unpack>(); + ${unpack_ivalues} + return ${op}_apply_functional(variable_list(grads), needs_input_grad${,apply_functional_args}); +#endif +} + +variable_list ${op}::apply(variable_list&& grads) { + ${thread_lock} + ${asserts} + ${unpacks} + ${compute_needs_input_grad} + return ${op}_apply_functional(std::move(grads), needs_input_grad${,apply_functional_args}); +} + +void ${op}::compiled_args(CompiledNodeArgs& args) const { + ${compiled_args} +} +variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) { +#ifdef C10_MOBILE + TORCH_INTERNAL_ASSERT(false, "compiled autograd doesn't work on mobile"); +#else + ${apply_with_saved_before} + + static bool called = false; + if (!called) { + called = true; + ${compute_schema} + const auto& pyinterface = torch::dynamo::autograd::getPyCompilerInterface(); + pyinterface->bind_function(saved.get_py_compiler(), name(), ${op}_apply_functional_ivalue, schema); + } + + variable_list output_result; + + PackedArgs packed_args; + ${asserts} + ${unpacks} + ${compute_needs_input_grad} + packed_args.pack(needs_input_grad); + ${get_packed_args} + + output_result = compiled_autograd_apply_functional(packed_args, next_edges(), saved, grads, name()); + + ${apply_with_saved_after} + return output_result; +#endif +} + +""" +) + +GRAD_INPUT_MASK = CodeTemplate( + """\ + auto grad_input_mask = std::array{ + ${masks} + }; +""" +) + +COMPUTE_NEEDS_INPUT_GRAD = CodeTemplate( + """\ +IndexRangeGenerator gen; +${compute_index_ranges} +auto needs_input_grad = std::array{ + ${masks} +};\ +""" +) + + +DERIVATIVE_SINGLE = CodeTemplate( + """\ +if (needs_input_grad[/*${name}*/${idx}]) { + auto grad_result = ${derivative}; + copy_range(grad_inputs, ${name}_ix, grad_result); +} +""" +) + +# note(crcrpar): `self` argument and other optional positional argument +# of foreach functions are basically a list of n `Tensor`s thus iterating over +# `grads` in order to utilize and apply the existing derivative definitions +# to each `Tensor`(s) of `self`, and the others. +DERIVATIVE_SINGLE_FOREACH = CodeTemplate( + """\ +if (needs_input_grad[/*${name}*/${idx}]) { // ${name} + std::vector grad_result; + grad_result.reserve(grads.size()); + for (const auto & i : c10::irange(grads.size())) { + if (grads[i].defined()) { + grad_result.emplace_back(${derivative}); + } else { + grad_result.emplace_back(Tensor()); + } + } + copy_range(grad_inputs, ${name}_ix, grad_result); +} +""" +) + +DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate( + """\ + if (needs_input_grad[/*${name}*/${idx}]) { + copy_range(grad_inputs, ${name}_ix, std::get<${i}>(grad_result)); + } +""" +) + +DERIVATIVE_MULTI = CodeTemplate( + """\ +if (${needs_input_grad}) { + ${grad_input_mask} + auto grad_result = ${derivative}; + ${copy_ranges} +} +""" +) + +# Generates python bindings +# +# This generates the definitions for: +# (1) The PyTypeObject for each backward grad_fn subclassing Node +# (2) The entry for PyTypeObject's tp_getset slot (an array of PyGetSetDef structs) +# We generate one PyGetSetDef struct for each of grad_fn's saved inputs and outputs +# Each PyGetSetDef has a function ptr to a getter, also defined here (3). +# (3) Getters for each of grad_fn's saved inputs and outputs. +# +PY_FUNCTION_DEFINITION = CodeTemplate( + """\ +static PyTypeObject ${op}Class; +addClass<${op}>(module, ${op}Class, "${op}", ${op}_properties); +""" +) + +PY_FUNCTION_PROPS_AND_GETTERS = CodeTemplate( + """\ +${all_getter_definitions} + +static struct PyGetSetDef ${op}_properties[] = { + THP_FUNCTION_DEFAULT_PROPERTIES, + ${all_getsetdef_structs} + {nullptr} /* sentinel */ +}; + +""" +) + +PY_GETSETDEF_STRUCT = CodeTemplate( + """\ +{(char*)"_saved_${name}", (getter)THP${op}_${name}_getter, nullptr, nullptr, nullptr}""" +) + +PY_RAW_GETSETDEF_STRUCT = CodeTemplate( + """\ +{(char*)"_raw_saved_${name}", (getter)THP${op}_${name}_raw_getter, nullptr, nullptr, nullptr}""" +) + +# Getter templates +GETTER_DEFINITION = CodeTemplate( + """\ +static PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { + HANDLE_TH_ERRORS + auto prop = static_cast<${op}*>(self->cdata.get())->${name}; + ${body} + END_HANDLE_TH_ERRORS +} +""" +) + +GETTER_DEFINITION_SAVEDVAR = CodeTemplate( + """\ +static PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { + HANDLE_TH_ERRORS + const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_; + ${body} + END_HANDLE_TH_ERRORS +} +""" +) + +GETTER_DEFINITION_RAW_SAVEDVAR = CodeTemplate( + """\ +static PyObject* THP${op}_${name}_raw_getter(THPCppFunction *self, void *_unused) { + HANDLE_TH_ERRORS + const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_; + ${body} + END_HANDLE_TH_ERRORS +} +""" +) + +GETTER_DEFINITION_VEC_SAVEDVAR = CodeTemplate( + """\ +static PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { + HANDLE_TH_ERRORS + const auto *node = static_cast<${op}*>(self->cdata.get()); + const auto& prop = node->${name}_; + if (node->${name}_released_) { + PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE); + return nullptr; + } + ${body} + END_HANDLE_TH_ERRORS +} +""" +) + +GETTER_DEFINITION_RAW_VEC_SAVEDVAR = CodeTemplate( + """\ +static PyObject* THP${op}_${name}_raw_getter(THPCppFunction *self, void *_unused) { + HANDLE_TH_ERRORS + const auto *node = static_cast<${op}*>(self->cdata.get()); + const auto& prop = node->${name}_; + if (node->${name}_released_) { + PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE); + return nullptr; + } + ${body} + END_HANDLE_TH_ERRORS +} +""" +) + +GETTER_DEFINITION_OPT = CodeTemplate( + """\ +static PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { + HANDLE_TH_ERRORS + auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name}; + if (!opt_prop.has_value()) { + Py_RETURN_NONE; + } + auto prop = opt_prop.value(); + ${body} + END_HANDLE_TH_ERRORS +} +""" +) + +GETTER_DEFINITION_OPT_ARRAYREF = CodeTemplate( + """\ +static PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { + HANDLE_TH_ERRORS + auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name}; + if (!opt_prop.list.has_value()) { + Py_RETURN_NONE; + } + auto prop = opt_prop.list.value(); + ${body} + END_HANDLE_TH_ERRORS +} +""" +) + +# Getter body +GETTER_BODY_SAVEDVAR = """\ +return THPVariable_Wrap(prop.unpack(self->cdata)); +""" + +GETTER_BODY_RAW_SAVEDVAR = """\ +pybind11::object obj = pybind11::cast(prop, pybind11::return_value_policy::reference); +return obj.release().ptr(); +""" + +GETTER_BODY_VEC_SAVEDVAR = """\ +PyObject* tup = PyTuple_New((Py_ssize_t) prop.size()); +for (auto i: c10::irange(prop.size())) { + PyTuple_SetItem(tup, (Py_ssize_t) i, THPVariable_Wrap(prop[i].unpack(self->cdata))); +} +return tup; +""" + +GETTER_BODY_RAW_VEC_SAVEDVAR = """\ +PyObject* tup = PyTuple_New((Py_ssize_t) prop.size()); +for (auto i : c10::irange(prop.size())) { + pybind11::object obj = pybind11::cast(prop[i], pybind11::return_value_policy::reference); + PyTuple_SetItem(tup, (Py_ssize_t) i, obj.release().ptr()); +} +return tup; +""" + +GETTER_BODY_ARRAYREF_LONG = """\ +PyObject* tup = PyTuple_New((Py_ssize_t) prop.size()); +for (auto i : c10::irange(prop.size())) { + PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong((uint64_t) prop[i])); +} +return tup; +""" + +GETTER_BODY_ARRAYREF_SYMINT = """\ +PyObject* tup = PyTuple_New((Py_ssize_t) prop.size()); +for (auto i : c10::irange(prop.size())) { + auto si = prop[i]; + if (auto m = si.maybe_as_int()) { + PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong(*m)); + } else { + auto py_symint = py::cast(si).release().ptr(); + PyTuple_SetItem(tup, (Py_ssize_t) i, py_symint); + } +} +return tup; +""" + +GETTER_BODY_ARRAYREF_DOUBLE = """\ +PyObject* tup = PyTuple_New((Py_ssize_t) prop.size()); +for (auto i : c10::irange(prop.size())) { + PyTuple_SetItem(tup, (Py_ssize_t) i, PyFloat_FromDouble((double) prop[i])); +} +return tup; +""" + +GETTER_BODY_INT64_T = """\ +return PyLong_FromUnsignedLong((int64_t) prop); +""" + +GETTER_BODY_SYMINT = """\ +if (auto m = prop.maybe_as_int()) { + return PyLong_FromUnsignedLong(*m); +} else { + return py::cast(prop).release().ptr(); +} +""" + +GETTER_BODY_DOUBLE = """\ +return PyFloat_FromDouble((double) prop); +""" + +GETTER_BODY_BOOL = """\ +if (prop) { + Py_RETURN_TRUE; +} else { + Py_RETURN_FALSE; +} +""" + +GETTER_BODY_STRING = """\ +return PyUnicode_FromStringAndSize(prop.data(), prop.size()); +""" + +GETTER_BODY_SCALAR = """\ +if (prop.isComplex()) { + auto cprop = prop.to>(); + return PyComplex_FromDoubles(cprop.real(), cprop.imag()); +} else if (prop.isFloatingPoint()) { + return PyFloat_FromDouble(prop.to()); +} else if (prop.isIntegral(/*includeBool=*/false)) { + return PyLong_FromLong(prop.to()); +} else if (prop.isBoolean()) { + if (prop.to()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} else { + PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type"); + return nullptr; +} +""" + + +GETTER_BODY_VEC_SCALAR = """\ +PyObject* tup = PyTuple_New((Py_ssize_t) prop.size()); +for (auto i: c10::irange(prop.size())) { + if (prop[i].isComplex()) { + auto cprop = prop[i].to>(); + PyTuple_SetItem(tup, (Py_ssize_t) i, PyComplex_FromDoubles(cprop.real(), cprop.imag())); + } else if (prop[i].isFloatingPoint()) { + auto double_prop = prop[i].to(); + PyTuple_SetItem(tup, (Py_ssize_t) i, PyFloat_FromDouble(double_prop)); + } else if (prop[i].isIntegral(/*includeBool=*/false)) { + auto long_prop = prop[i].to(); + PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromLong(long_prop)); + } else if (prop[i].isBoolean()) { + if (prop[i].to()) { + PyTuple_SetItem(tup, (Py_ssize_t) i, Py_True); + } else { + PyTuple_SetItem(tup, (Py_ssize_t) i, Py_False); + } + } else { + PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type"); + return nullptr; + } +} +return tup; +""" + + +MISC_GETTER_DEFS = { + OptionalCType(BaseCType(longT)): (GETTER_DEFINITION_OPT, GETTER_BODY_INT64_T), + OptionalCType(BaseCType(SymIntT)): (GETTER_DEFINITION_OPT, GETTER_BODY_SYMINT), + BaseCType(doubleT): (GETTER_DEFINITION, GETTER_BODY_DOUBLE), + OptionalCType(BaseCType(doubleT)): (GETTER_DEFINITION_OPT, GETTER_BODY_DOUBLE), + BaseCType(boolT): (GETTER_DEFINITION, GETTER_BODY_BOOL), + BaseCType(scalarT): (GETTER_DEFINITION, GETTER_BODY_SCALAR), + OptionalCType(BaseCType(scalarT)): (GETTER_DEFINITION_OPT, GETTER_BODY_SCALAR), +} + +# These functions have backwards which cannot be traced, and so must have +# their backward functions traced opaquely. +# VIEW_FUNCTIONS are not traceable because they use as_strided, which +# has an untraceable backwards, see +# https://github.com/pytorch/pytorch/issues/4250 +# TODO: This is probably not exhaustive, but it's a start +UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS + + +def get_infos_with_derivatives_list( + differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]], +) -> list[DifferentiabilityInfo]: + diff_info_list = [ + info + for diffinfo_dict in differentiability_infos.values() + for info in diffinfo_dict.values() + ] + + return list(filter(lambda info: info.args_with_derivatives, diff_info_list)) + + +def gen_autograd_functions_lib( + out: str, + differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]], + template_path: str, +) -> None: + """Functions.h and Functions.cpp body + + These contain the auto-generated subclasses of torch::autograd::Node + for each every differentiable torch function. + """ + + # get a 1D list of diffinfos, we do not need them to be per FunctionSchema/DispatchKey here + # infos with the diff dispatchkeys but the same name will still be in the same shard. + infos = get_infos_with_derivatives_list(differentiability_infos) + declarations = [process_function(f, FUNCTION_DECLARATION) for f in infos] + definitions = [process_function(f, FUNCTION_DEFINITION) for f in infos] + + file_basename = "Functions" + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + for suffix in [".h", ".cpp"]: + fname = file_basename + suffix + fm.write_with_template( + fname, + fname, + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/{fname}", + "autograd_function_declarations": declarations, + "autograd_function_definitions": definitions, + }, + ) + + +def gen_autograd_functions_python( + out: str, + differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]], + template_path: str, +) -> None: + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + num_shards = 5 + fm.write( + "python_functions.h", + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/python_functions.h", + "shard_forward_declare": [ + f"void initialize_autogenerated_functions_{i}(PyObject* module);" + for i in range(num_shards) + ], + "shard_call": [ + f"initialize_autogenerated_functions_{i}(module);" + for i in range(num_shards) + ], + }, + ) + + # get a 1D list of diffinfos, we do not need them to be per FunctionSchema/DispatchKey here + # infos with the diff dispatchkeys but the same name will still be in the same shard. + infos = get_infos_with_derivatives_list(differentiability_infos) + fm.write_sharded( + "python_functions.cpp", + infos, + key_fn=lambda info: info.name, + base_env={ + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/python_functions.cpp", + }, + env_callable=lambda info: { + "py_function_initializers": [ + process_function(info, PY_FUNCTION_DEFINITION) + ], + "py_function_props_and_getters": [ + process_function(info, PY_FUNCTION_PROPS_AND_GETTERS) + ], + }, + num_shards=num_shards, + sharded_keys={"py_function_initializers", "py_function_props_and_getters"}, + ) + + +def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str: + saved_variables: list[str] = [] + release_variables: list[str] = [] + saved_list_sizes: list[str] = [] + unpack: list[str] = [] + asserts: list[str] = [] + compute_index_ranges: list[str] = [] + getter_definitions: list[str] = [] + py_getsetdef_structs: list[str] = [] + compiled_args: list[str] = [] + apply_with_saved_before: list[str] = [] + apply_with_saved_after: list[str] = [] + apply_functional_args: list[str] = [] + apply_functional_args_ref_types: list[str] = [] + # Maps the name of an input (to the original forward operator; + # examples are "self", "other") to the order in which they appear in the + # operator. + # For example; if the operator is foo(Tensor self, int64_t k, Tensor other), + # the mapping is: {"self": 0, "other": 1}. + # We use this mapping to populate needs_input_grad in some order and then grab + # values from it. + input_name_to_idx: dict[str, int] = {} + + for idx, arg in enumerate(info.args_with_derivatives): + if arg.type in TENSOR_LIST_LIKE_CTYPES: + size = f"{arg.name}_size_" + saved_list_sizes.append(f"size_t {arg.name}_size_;") + apply_functional_args.append(f"{arg.name}_size_") + apply_functional_args_ref_types.append("size_t") + else: + size = "1" + compute_index_ranges.append(f"auto {arg.name}_ix = gen.range({size});") + input_name_to_idx[arg.name] = idx + + def save_var(var: SavedAttribute, is_output: bool) -> None: + name = var.nctype.name + type = var.nctype.type + should_append_getsetdef = True + should_append_raw_getsetdef = False + visit_name = name + uses_cpp_saved_variable_cls = False + unpacked_ref_type = None + + if ( + type == BaseCType(tensorT) + or type == OptionalCType(BaseCType(tensorT)) + or type == MutRefCType(OptionalCType(BaseCType(tensorT))) + or (type == BaseCType(scalarT) and is_output) + ): + uses_cpp_saved_variable_cls = True + saved_variables.append(f"SavedVariable {name}_;") + release_variables.append(f"{name}_.reset_data();") + ptr = "shared_from_this()" if is_output else "" + unpack.append(f"auto {name} = {name}_.unpack({ptr});") + getter_definitions.append( + GETTER_DEFINITION_SAVEDVAR.substitute( + op=info.op, name=name, body=GETTER_BODY_SAVEDVAR + ) + ) + getter_definitions.append( + GETTER_DEFINITION_RAW_SAVEDVAR.substitute( + op=info.op, name=name, body=GETTER_BODY_RAW_SAVEDVAR + ) + ) + should_append_raw_getsetdef = True + visit_name = f"{name}_" + unpacked_ref_type = "Tensor&" + elif ( + type == BaseCType(tensorListT) + or type == BaseCType(iTensorListRefT) + or type == VectorCType(BaseCType(tensorT)) + ): + # note(crcrpar): [nuanced return type of out-of-place foreach functions] + # When an out-of-place foreach function whose return signature is `Tensor[]` + # spells out its backward definitions in `derivatives.yaml`, and some of them depend on + # `result`, `result`'s type is interpreted and treated as `std::vector`. + # An out-of-place foreach whose backwards rely on their output doesn't suffer from this + # difference if the definitions are codegen'ed. + # This special case is needed for `_foreach_pow.List` and `_foreach_pow.ScalarAndTensor` + # as of https://github.com/pytorch/pytorch/pull/105504. + if type == VectorCType(BaseCType(tensorT)): + assert ( + info.func.func.name.name.base.startswith("_foreach") and is_output + ) + uses_cpp_saved_variable_cls = True + saved_variables.append(f"std::vector {name}_;") + saved_variables.append(f"bool {name}_released_ = false;") + # Just clear() is sufficient, we don't need to loop and clear each variable. + # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well. + release_variables.append(f"{name}_.clear();") + release_variables.append(f"{name}_released_ = true;") + ptr = "shared_from_this()" if is_output else "nullptr" + unpack.append(f"auto {name} = unpack_list({name}_, {ptr});") + asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);") + getter_definitions.append( + GETTER_DEFINITION_VEC_SAVEDVAR.substitute( + op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR + ) + ) + getter_definitions.append( + GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute( + op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR + ) + ) + should_append_raw_getsetdef = True + visit_name = f"{name}_" + unpacked_ref_type = "std::vector&" + elif type == ListCType(OptionalCType(BaseCType(tensorT))): + uses_cpp_saved_variable_cls = True + saved_variables.append(f"std::vector {name}_;") + saved_variables.append(f"bool {name}_released_ = false;") + # Just clear() is sufficient, we don't need to loop and clear each variable. + # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well. + release_variables.append(f"{name}_.clear();") + release_variables.append(f"{name}_released_ = true;") + unpack.append(f"auto {name} = unpack_opt_list({name}_);") + asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);") + getter_definitions.append( + GETTER_DEFINITION_VEC_SAVEDVAR.substitute( + op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR + ) + ) + getter_definitions.append( + GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute( + op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR + ) + ) + should_append_raw_getsetdef = True + visit_name = f"{name}_" + unpacked_ref_type = "torch::List>&" + elif type == BaseCType(intArrayRefT): + saved_variables.append(f"std::vector {name};") + getter_definitions.append( + GETTER_DEFINITION.substitute( + op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG + ) + ) + elif type == BaseCType(symIntArrayRefT): + saved_variables.append(f"std::vector {name};") + getter_definitions.append( + GETTER_DEFINITION.substitute( + op=info.op, name=name, body=GETTER_BODY_ARRAYREF_SYMINT + ) + ) + elif type == BaseCType(optionalIntArrayRefT): + saved_variables.append(f"c10::OptionalArray {name};") + getter_definitions.append( + GETTER_DEFINITION_OPT_ARRAYREF.substitute( + op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG + ) + ) + elif type == BaseCType(optionalSymIntArrayRefT): + saved_variables.append(f"c10::OptionalArray {name};") + getter_definitions.append( + GETTER_DEFINITION_OPT_ARRAYREF.substitute( + op=info.op, name=name, body=GETTER_BODY_ARRAYREF_SYMINT + ) + ) + elif type == OptionalCType(BaseCType(intArrayRefT)): + saved_variables.append(f"c10::OptionalArray {name};") + getter_definitions.append( + GETTER_DEFINITION_OPT_ARRAYREF.substitute( + op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG + ) + ) + elif type == OptionalCType(BaseCType(symIntArrayRefT)): + saved_variables.append(f"c10::OptionalArray {name};") + getter_definitions.append( + GETTER_DEFINITION_OPT_ARRAYREF.substitute( + op=info.op, name=name, body=GETTER_BODY_ARRAYREF_SYMINT + ) + ) + elif type == OptionalCType(ArrayRefCType(BaseCType(doubleT))): + saved_variables.append(f"c10::OptionalArray {name};") + getter_definitions.append( + GETTER_DEFINITION_OPT_ARRAYREF.substitute( + op=info.op, name=name, body=GETTER_BODY_ARRAYREF_DOUBLE + ) + ) + elif type == BaseCType(longT): + saved_variables.append(f"{type.cpp_type()} {name} = 0;") + getter_definitions.append( + GETTER_DEFINITION.substitute( + op=info.op, name=name, body=GETTER_BODY_INT64_T + ) + ) + elif type == BaseCType(SymIntT): + saved_variables.append(f"c10::SymInt {name};") + getter_definitions.append( + GETTER_DEFINITION.substitute( + op=info.op, name=name, body=GETTER_BODY_SYMINT + ) + ) + elif type == BaseCType(stringT): + saved_variables.append(f"std::string {name};") + getter_definitions.append( + GETTER_DEFINITION.substitute( + op=info.op, name=name, body=GETTER_BODY_STRING + ) + ) + elif type == OptionalCType(BaseCType(stringT)): + saved_variables.append(f"std::optional {name};") + getter_definitions.append( + GETTER_DEFINITION_OPT.substitute( + op=info.op, name=name, body=GETTER_BODY_STRING + ) + ) + elif type == ArrayRefCType( + elem=BaseCType(type=BaseCppType(ns="at", name="Scalar")) + ): + saved_variables.append(f"std::vector {name};") + unpacked_ref_type = "std::vector&" + saved_variables.append(f"bool {name}_released_ = false;") + # Just clear() is sufficient, we don't need to loop and clear each variable. + # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well. + release_variables.append(f"{name}.clear();") + # release_variables.append(f"{name}_released_ = true;") + # unpack.append(f"auto {name} = unpack_list({name}_);") + # asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);") + getter_definitions.append( + CodeTemplate( + """\ +static PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) { + HANDLE_TH_ERRORS + const auto *node = static_cast<${op}*>(self->cdata.get()); + const auto& prop = node->${name}; + if (node->${name}_released_) { + PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE); + return nullptr; + } + ${body} + END_HANDLE_TH_ERRORS +} + """ + ).substitute( + op=info.op, + name=name, + body=GETTER_BODY_VEC_SCALAR, + ) + ) + else: + # Check for indicators that you're putting a non-owning reference + # into the saved variable field. If this is spuriously firing, + # edit this field. Otherwise, you probably need to add a case + # above. + assert ( + "ref" not in type.cpp_type().lower() + and "view" not in type.cpp_type().lower() + and "*" not in type.cpp_type() + and "&" not in type.cpp_type() + ), f"{type.cpp_type()} looks like it contains a non-owning reference" + saved_variables.append(f"{type.cpp_type()} {name};") + + if type in MISC_GETTER_DEFS: + # pyrefly: ignore [index-error] + getter_def, body = MISC_GETTER_DEFS[type] + getter_definitions.append( + getter_def.substitute(op=info.op, name=name, body=body) + ) + else: + # Types we don't expose python bindings to yet: + # TypeAndSize, at::ScalarType, TensorOptions, TensorGeometry, + # std::vector>, std::vector + should_append_getsetdef = False + + if should_append_getsetdef: + py_getsetdef_structs.append( + PY_GETSETDEF_STRUCT.substitute(op=info.op, name=name) + ) + if should_append_raw_getsetdef: + py_getsetdef_structs.append( + PY_RAW_GETSETDEF_STRUCT.substitute(op=info.op, name=name) + ) + + if uses_cpp_saved_variable_cls: + compiled_args.append( + f"args.collect({visit_name}, {'true' if is_output else 'false'});" + ) + else: + compiled_args.append(f"args.collect({visit_name});") + apply_with_saved_before.append(f"saved.before({visit_name});") + apply_with_saved_after.append(f"saved.after({visit_name});") + + if unpacked_ref_type is None: + unpacked_ref_type = f"{saved_variables[-1].split(' ')[0]}&" + apply_functional_args.append(str(name)) + apply_functional_args_ref_types.append(unpacked_ref_type) + + for var in sorted(info.all_saved_inputs, key=lambda sa: str(sa.nctype.name)): + save_var(var, is_output=False) + for var in sorted(info.all_saved_outputs, key=lambda sa: str(sa.nctype.name)): + save_var(var, is_output=True) + + # lock the mutex when we release variables and in Node::apply to protect thread safety + # see Note [Thread Safety on Autograd Node] + if len(release_variables) > 0: + thread_lock = "std::lock_guard lock(mutex_);" + else: + thread_lock = "" + + if uses_retain_variables(info): + apply_functional_args.append("retain_variables") + apply_functional_args_ref_types.append("bool") + will_release_variables = WILL_RELEASE_VARIABLES.substitute() + else: + will_release_variables = "" + + body: list[str] = [] + + if uses_single_grad(info): + body.append("const auto& grad = grads[0];") + else: + # Generate aliases for gradients named for returned values. + body.extend( + f"const auto& {name} = grads[{info.available_named_gradients.index(name)}];" + for name in sorted(info.used_named_gradients) + ) + + def emit_derivative( + derivative: Derivative, + args_with_derivatives: Sequence[Binding], + ) -> tuple[bool, str]: + formula = derivative.formula + var_names = derivative.var_names + + if len(var_names) == 1: + checks_any_grad_defined = False + if "not_implemented" not in formula: + matching_args = [ + arg for arg in args_with_derivatives if arg.name == var_names[0] + ] + if len(matching_args) == 1: + # We can add undefined grad support if the input variable is a Tensor + arg = matching_args[0] + if isinstance(arg.argument, Argument) and str( + arg.argument.type + ) in ("Tensor", "Tensor?"): + formula = "any_grad_defined ? (" + formula + ") : Tensor()" + checks_any_grad_defined = True + if info.name.startswith("_foreach_"): + derivative_template = DERIVATIVE_SINGLE_FOREACH + else: + derivative_template = DERIVATIVE_SINGLE + return ( + checks_any_grad_defined, + derivative_template.substitute( + name=var_names[0], + derivative=formula, + idx=input_name_to_idx[var_names[0]], + ), + ) + + else: + if "grad_input_mask" in formula: + masks = [ + f"needs_input_grad[{input_name_to_idx[name]}]," + for name in var_names + ] + grad_input_mask = GRAD_INPUT_MASK.substitute( + n=len(var_names), masks=masks + ) + else: + grad_input_mask = "" + needs_input_grad = [ + f"needs_input_grad[{input_name_to_idx[name]}]" for name in var_names + ] + needs_input_grad = " || ".join(needs_input_grad) + copy_ranges: list[str] = [] + for i, n in enumerate(var_names): + copy_ranges.append( + DERIVATIVE_MULTI_COPY_RANGE.substitute( + name=n, i=i, idx=input_name_to_idx[n] + ) + ) + return False, DERIVATIVE_MULTI.substitute( + needs_input_grad=needs_input_grad, + copy_ranges=copy_ranges, + derivative=formula, + grad_input_mask=grad_input_mask, + ) + + masks = [] + + need_any_grad_defined_var = False + for derivative in info.derivatives: + checks_any_grad_defined, derivative_text = emit_derivative( + derivative, info.args_with_derivatives + ) + body.append(derivative_text) + need_any_grad_defined_var |= checks_any_grad_defined + + for name in input_name_to_idx: + masks.append(f"task_should_compute_output({{ {name}_ix }}),") + + # Since single-output derivative formulas need to check if grads are + # defined, only perform the check once, before all the formulas + if need_any_grad_defined_var: + body.insert( + -len(info.derivatives), + "bool any_grad_defined = any_variable_defined(grads);", + ) + + if info.name in UNTRACEABLE_FUNCTIONS: + superclass = "Node" + else: + superclass = "TraceableFunction" + + all_getsetdef_structs = ( + ",\n".join(py_getsetdef_structs) + "," if len(py_getsetdef_structs) != 0 else "" + ) + all_getter_definitions = "\n".join(getter_definitions) + + compute_needs_input_grad = COMPUTE_NEEDS_INPUT_GRAD.substitute( + n=len(masks), compute_index_ranges=compute_index_ranges, masks=masks + ) + apply_functional_args_signature = [ + f"{T} {x}" + for T, x in zip(apply_functional_args_ref_types, apply_functional_args) + ] + get_packed_args = "\n".join( + f"packed_args.pack({name});" for name in apply_functional_args + ) + unpack_ivalues = [] + for typ, name in zip(apply_functional_args_ref_types, apply_functional_args): + typ = typ.removesuffix("&") + # pyrefly: ignore [bad-argument-type] + unpack_ivalues.append(f"auto {name} = packed_args.unpack<{typ}>();") + + schema_args = [f"std::array"] + for typ in apply_functional_args_ref_types: + typ = typ.removesuffix("&") + typ = typ.removeprefix("const") + schema_args.append(typ.strip()) + compute_schema = ["std::vector schema = {"] + for schema_arg in schema_args: + compute_schema.append( + f" torch::dynamo::autograd::IValuePacker<{schema_arg}>::packed_type()," + ) + compute_schema.append("};") + + return template.substitute( + unpacks="\n".join(unpack), + op=info.op, + compute_schema="\n".join(compute_schema), + apply_functional_args=apply_functional_args, + apply_functional_args_signature=apply_functional_args_signature, + compute_needs_input_grad=compute_needs_input_grad, + num_inputs=len(input_name_to_idx), + unpack_ivalues="\n".join(unpack_ivalues), + compute_index_ranges=compute_index_ranges, + saved_variables=saved_variables, + release_variables=release_variables, + saved_list_sizes=saved_list_sizes, + asserts=asserts, + thread_lock=thread_lock, + will_release_variables=will_release_variables, + body=body, + superclass=superclass, + all_getter_definitions=all_getter_definitions, + all_getsetdef_structs=all_getsetdef_structs, + compiled_args=compiled_args, + apply_with_saved_before=apply_with_saved_before, + apply_with_saved_after=apply_with_saved_after, + get_packed_args=get_packed_args, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_inplace_or_view_type.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_inplace_or_view_type.py new file mode 100644 index 0000000000000000000000000000000000000000..4cb3429c39276ec2ad62ff111e7226512b31596f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_inplace_or_view_type.py @@ -0,0 +1,673 @@ +# Generates ADInplaceOrViewType.h/cpp +# +# NOTE: If any changes are being made to the ADInplaceOrView codegen please also check +# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp +# The fallback is expected to mimic this codegen, so we should keep the two in sync. + +from __future__ import annotations + +from torchgen.api import cpp +from torchgen.api.autograd import ( + dispatch_strategy, + gen_differentiable_outputs, + NativeFunctionWithDifferentiabilityInfo, +) +from torchgen.api.types import ( + BaseCType, + Binding, + boolT, + ConstRefCType, + CType, + DispatcherSignature, + intArrayRefT, + longT, + OptionalCType, + symIntArrayRefT, + SymIntT, + tensorT, +) +from torchgen.code_template import CodeTemplate +from torchgen.context import with_native_function +from torchgen.model import ( + NativeFunction, + SchemaKind, + SelfArgument, + TensorOptionsArguments, + Type, +) +from torchgen.utils import FileManager + +from .context import with_native_function_with_differentiability_info +from .gen_trace_type import ( + get_return_value, + MANUAL_AUTOGRAD, + tie_return_values, + type_wrapper_name, +) + + +# See NOTE [ Autograd View Variables ] in variable.h for details. +# If you update list VIEW_FUNCTIONS or RETURNS_VIEWS_OF_INPUT, +# you **MUST** also update the public list of view ops accordingly in +# docs/source/tensor_view.rst. Note not all ATen functions are exposed to public, +# e.g alias & sparse_coo_tensor_with_dims_and_tensors. +# +# A map: function name => name of the argument that all outputs are view of + +VIEW_FUNCTIONS_WITH_METADATA_CHANGE = [ + "view_as_complex", + "view_as_real", + "_conj", + "_neg_view", + "_nested_get_values", + "_nested_view_from_buffer", + "_nested_view_from_jagged", +] + +VIEW_FUNCTIONS = { + "numpy_T": "self", + "alias": "self", + "as_strided": "self", + "diagonal": "self", + "expand": "self", + "permute": "self", + "select": "self", + "slice": "self", + "slice_inverse": "self", + "split": "self", + "split_with_sizes": "self", + "squeeze": "self", + "t": "self", + "transpose": "self", + "unfold": "self", + "unsqueeze": "self", + "flatten": "self", + "view": "self", + "unbind": "self", + "_indices": "self", + "_values": "self", + "indices": "self", + "values": "self", + "crow_indices": "self", + "col_indices": "self", + "ccol_indices": "self", + "row_indices": "self", + # sparse_coo ctor output should really be views of both indices and values, + # but we only supports making as view of a single variable, and indices is + # discrete anyways. + # FIXME: clone indices on construction. + "sparse_coo_tensor_with_dims_and_tensors": "values", + "_reshape_alias": "self", + "_test_autograd_multiple_dispatch_view": "self", +} + +for key in VIEW_FUNCTIONS_WITH_METADATA_CHANGE: + VIEW_FUNCTIONS[key] = "self" + +# note: some VIEW_FUNCTIONS are just compositions of the view functions above +# this list contains both the root view functions and any that are purely composed +# of viewing functions, and is used by the JIT to determine when an operator +# may return a view of its inputs; however they may sometimes return a copy. +# (e.g. `contiguous`) +RETURNS_VIEWS_OF_INPUT = set(VIEW_FUNCTIONS.keys()).union( + { + "chunk", + "detach", + "contiguous", + "reshape", + "reshape_as", + "expand_as", + "view_as", + "real", + "imag", + "narrow", + "movedim", + "tensor_split", + "swapdims", + "swapaxes", + "mT", + "mH", + "adjoint", + "matrix_H", + } +) + +# These are the functions we consider views for the purposes of validating +# StorageImpl and TensorImpl in gen_variable_type. +# `_unsafe_view` is not included in VIEW_FUNCTIONS above because it is not a +# view for the purposes of ADInplaceOrView kernel, we do not want to call as_view +# See NOTE [Unsafe View] for more info. +ALL_VIEW_FUNCTIONS = { + **VIEW_FUNCTIONS, + "_unsafe_view": "self", +} + +ARRAYREF_TO_VEC = CodeTemplate( + """\ +auto ${vec} = ${arg}.vec(); +""" +) + +OPTIONAL_TO_VAL = CodeTemplate( + """\ +auto ${val} = ${arg}.value_or(${default}); +""" +) + +CALL_DISPATCH = CodeTemplate( + """\ +at::_ops::${unambiguous_name}::call(${unpacked_args})""" +) + +REVERSE_VIEW_DISPATCH = CodeTemplate( + """\ +${reverse_name}(${unpacked_args})""" +) + +MULTI_OUTPUT_VIEW_ITERATION = CodeTemplate( + """\ +for (auto ${view_idx} : c10::irange(${var}.size())) { + ${body} +} +""" +) + +SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE = CodeTemplate( + """\ +std::unique_ptr func(nullptr); +std::function rev_func=nullptr; +if (${is_view_with_metadata_change} || + !self.unsafeGetTensorImpl()->support_as_strided() || + self.unsafeGetTensorImpl()->is_python_dispatch() || + c10::AutogradState::get_tls_state().get_view_replay_enabled()) { + ${replay_view_func} + ${reverse_replay_view_func} +} +""" +) + +REPLAY_VIEW_FUNC = CodeTemplate( + """\ +func = std::make_unique<${view_func_name}>(${view_func_args}); +""" +) + +REVERSE_REPLAY_VIEW_LAMBDA_FUNC = CodeTemplate( + """\ +rev_func = [=](const at::Tensor& ${input_view}) { + return ${reverse_replay_view_call}; +}; +""" +) + +METHOD_DEFINITION = CodeTemplate( + """\ +${return_type} ${type_wrapper_name}(${formals}) { + ${type_definition_body} +} +""" +) + +WRAPPER_REGISTRATION = CodeTemplate( + """\ +m.impl("${unqual_operator_name_with_overload}", + TORCH_FN(${class_type}::${type_wrapper_name}) +); +""" +) + +AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION = CodeTemplate( + """\ +m.impl("${unqual_operator_name_with_overload}", torch::autograd::autogradNotImplementedFallback()); +""" +) + +INPLACE_REDISPATCH = CodeTemplate( + """\ +{ + at::AutoDispatchBelowADInplaceOrView guard; + at::_ops::${unambiguous_name}::redispatch(${unpacked_args}); +} +""" +) + +ASSIGN_RETURN_VALUE = CodeTemplate( + """\ +${return_values} = ${rhs_value}; +""" +) + +VIEW_REDISPATCH = CodeTemplate( + """\ +${assign_return_values} ([&]() { + at::AutoDispatchBelowADInplaceOrView guard; + return at::_ops::${unambiguous_name}::redispatch(${unpacked_args}); +})(); +""" +) + +TMP_VAR = "_tmp" + + +# FIXME: Ideally these functions should be methods on Type class, but we have a +# comment in codegen/model.py there saying these concepts are not well defined. +# Thus we put a version that commonly used by autograd codegen here. +def is_tensor_type(t: Type) -> bool: + # TODO: Should handle optional here? + return t.is_tensor_like() and t.is_list_like() is None + + +def is_tensor_list_type(t: Type) -> bool: + # TODO: Should handle optional here? + return t.is_tensor_like() and t.is_list_like() is not None + + +UNPACK_TENSOR = CodeTemplate( + """\ +auto${ref} ${arg_name}_ = unpack${suffix}(${arg_name}, "${arg_name}", ${arg_pos});""" +) + + +def unpacked_name(arg_name: str) -> str: + return arg_name + "_" + + +# e.g. select.int -> select_copy_int_inverse() +def inverse_view_name(f: NativeFunction) -> str: + copy_variant = f"{f.root_name}_copy" + overload = f"{f.func.name.overload_name}" + if overload != "": + overload = "_" + overload + return f"{copy_variant}{overload}_inverse" + + +def extract_bindings(f: NativeFunction) -> list[Binding]: + return [ + r + for a in f.func.schema_order_arguments() + for r in cpp.argument( + a, + method=False, + symint=True, + cpp_no_default_args=set(), + faithful=False, + has_tensor_options=False, + ) + ] + + +@with_native_function +def unpack_args(f: NativeFunction) -> tuple[list[str], list[Binding]]: + body: list[str] = [] + unpacked_bindings: list[Binding] = [] + + for i, binding in enumerate(extract_bindings(f)): + assert not isinstance(binding.argument, SelfArgument) + if isinstance(binding.argument, TensorOptionsArguments): + raise RuntimeError("VariableKernel shouldn't take TensorOptions") + + is_nullable = binding.argument.type.is_nullable() + if not binding.argument.type.is_tensor_like() or is_nullable: + unpacked_bindings.append(binding) + continue + + is_tensor_list = is_tensor_list_type(binding.argument.type) + ref = (not is_nullable) and not is_tensor_list + suffix = "_opt" if is_nullable and not is_tensor_list else "" + body.append( + UNPACK_TENSOR.substitute( + arg_name=binding.name, + arg_pos=i, + suffix=suffix, + ref="&" if ref else "", + ) + ) + unpacked_bindings.append( + Binding( + name=unpacked_name(binding.name), + nctype=binding.nctype, + argument=binding.argument, + default=binding.default, + ) + ) + + return body, unpacked_bindings + + +def get_base_name(f: NativeFunction) -> str: + return f.func.name.name.base # TODO: should be str(f.func.name.name)? + + +def get_view_info(f: NativeFunction) -> str | None: + base_name = get_base_name(f) + view_info = VIEW_FUNCTIONS.get(base_name) + if view_info is None and base_name in RETURNS_VIEWS_OF_INPUT: + view_info = "self" + return view_info + + +def emit_view_func( + f: NativeFunction, bindings: list[Binding], view_idx: str | None = None +) -> str: + """Generate an additional lambda function to recover views in backward when as_strided is not supported. + See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details. + """ + # TODO: Clean this logic up if we get rid of reverse view funcs or reify them. + input_base = "input_base" + replay_view_func = "" + updated_args: list[str] = [] + known_view_arg_simple_types: list[CType] = [ + BaseCType(longT), + OptionalCType(BaseCType(longT)), + BaseCType(SymIntT), + OptionalCType(BaseCType(SymIntT)), + BaseCType(boolT), + BaseCType(intArrayRefT), + BaseCType(symIntArrayRefT), + ConstRefCType(BaseCType(tensorT)), + ConstRefCType(OptionalCType(BaseCType(tensorT))), + ] + for binding in bindings: + arg, arg_type = binding.name, binding.nctype.type + if arg == "self": + updated_args.append(input_base) + continue + if arg_type not in known_view_arg_simple_types: + known_types_str = ", ".join([str(t) for t in known_view_arg_simple_types]) + raise TypeError( + f"You are adding an {arg_type} {arg} argument to op {cpp.name(f.func)} in addition to known types: " + f"{known_types_str}. Please update the list or materialize it so that it can be closed " + "over by value, also add a test in pytorch/xla/test/test_operations.py where this code " + "is exercised." + ) + if arg_type == BaseCType(intArrayRefT) or arg_type == BaseCType( + symIntArrayRefT + ): + # It's not safe to close over IntArrayRef by value, since this is a + # reference type, so materialize a vector to close over by value + arg_vec = arg + "_vec" + replay_view_func += ARRAYREF_TO_VEC.substitute(arg=arg, vec=arg_vec) + updated_args.append(arg_vec) + elif arg_type == OptionalCType(BaseCType(longT)): + # Materialize int64_t? to int64_t + arg_value = arg + "_val" + replay_view_func += OPTIONAL_TO_VAL.substitute( + arg=arg, val=arg_value, default="0" + ) + updated_args.append(arg_value) + elif arg_type == ConstRefCType(BaseCType(tensorT)) or arg_type == ConstRefCType( + OptionalCType(BaseCType(tensorT)) + ): + # NB: Closing over a tensor. If a user modifies this tensor, this will be silently + # incorrect. The proper thing to do is to store the version counter and copy on write. + updated_args.append(arg) + else: + updated_args.append(arg) + + from .gen_view_funcs import view_func_name + + view_func_args = [b.name for b in bindings if b.name != "self"] + if view_idx is not None: + view_func_args.append(f"{view_idx}") + replay_view_func += REPLAY_VIEW_FUNC.substitute( + view_func_name=view_func_name(f, include_namespace=True), + view_func_args=view_func_args, + ) + + input_view = "input_view" + reverse_unpacked_args = [ + "self", + f"{input_view}", + # inverse_return_mode= + "at::functionalization::InverseReturnMode::AlwaysView", + *(() if view_idx is None else (f"{view_idx}",)), + # skip input_base arg + *updated_args[1:], + ] + + from torchgen.api.functionalization import reverse_name + + reverse_replay_view_call = REVERSE_VIEW_DISPATCH.substitute( + reverse_name=reverse_name(f, include_namespace=True), + unpacked_args=reverse_unpacked_args, + ) + reverse_replay_view_func = REVERSE_REPLAY_VIEW_LAMBDA_FUNC.substitute( + input_view=input_view, reverse_replay_view_call=reverse_replay_view_call + ) + + is_view_with_metadata_change = ( + "true" if cpp.name(f.func) in VIEW_FUNCTIONS_WITH_METADATA_CHANGE else "false" + ) + + return SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE.substitute( + is_view_with_metadata_change=is_view_with_metadata_change, + replay_view_func=replay_view_func, + reverse_replay_view_func=reverse_replay_view_func, + ) + + +def emit_view_body( + fn: NativeFunctionWithDifferentiabilityInfo, var: str +) -> tuple[str, str]: + # See NOTE [ Autograd View Variables ] in variable.h for details. + f = fn.func + base_name = get_base_name(f) + view_info = get_view_info(f) + call = "" + differentiable_outputs = gen_differentiable_outputs(fn) + differentiable_output_vars = {r.name for r in differentiable_outputs} + if not isinstance(view_info, str): + raise TypeError( + f"The view info should be a string for {base_name}, but it is: {view_info}" + ) + if len(differentiable_output_vars) == 0: + # no output is differentiable (.indices() for SparseTensors for example) + rhs_value = ( + f"as_view({view_info}, {var}, " + f"/* is_bw_differentiable */ false, /* is_fw_differentiable */ false)" + ) + elif len(differentiable_output_vars) == 1: + # Single differentiable output (Tensor or Tensor[]) + return_info = differentiable_outputs[0] + # We only support simple Tensor or a TensorList for functions that return views + if not is_tensor_type(return_info.type) and not is_tensor_list_type( + return_info.type + ): + raise RuntimeError( + f"{base_name} that return differentiable views can only return Tensor or Tensor[]" + ) + + # See Note [ View + Inplace detection] + def get_creation_meta_in_mode(original: str) -> str: + creation_meta_with_grad_mode = f"(at::GradMode::is_enabled() ? {original} : CreationMeta::NO_GRAD_MODE)" + return f"InferenceMode::is_enabled() ? CreationMeta::INFERENCE_MODE : {creation_meta_with_grad_mode}" + + # Only allow rebasing of the history if we return a single Tensor + # If we are in a no grad block, raise a warning + # See NOTE [ View + Inplace detection ] for more details about this logic + if is_tensor_list_type(return_info.type): + creation_meta = get_creation_meta_in_mode("CreationMeta::MULTI_OUTPUT_NODE") + view_idx = "view_idx" + view_func = emit_view_func( + f, extract_bindings(f), view_idx=view_idx + ).strip() + as_view_call = ( + f"as_view(/* base */ {view_info}, /* output */ {var}[{view_idx}], " + "/* is_bw_differentiable */ true, /* is_fw_differentiable */ true, " + "/* view_func */ std::move(func), /* rev_view_func */ rev_func, " + f"/* creation_meta */ {creation_meta});" + ) + call += MULTI_OUTPUT_VIEW_ITERATION.substitute( + var=var, view_idx=view_idx, body=f"{view_func}\n{as_view_call}" + ) + rhs_value = f"std::move({var})" + else: + call += emit_view_func(f, extract_bindings(f), view_idx=None) + creation_meta = get_creation_meta_in_mode("CreationMeta::DEFAULT") + rhs_value = ( + f"as_view(/* base */ {view_info}, /* output */ {var}, /* is_bw_differentiable */ true, " + "/* is_fw_differentiable */ true, " + f"/* view_func */ std::move(func), /* rev_view_func */ rev_func, /* creation_meta */ {creation_meta})" + ) + else: + # This could be supported but we don't need it at the moment, so keeping things simple. + raise RuntimeError( + "Function that return multiple differentiable output " + "when at least one of them is view is not supported." + ) + return call, rhs_value + + +def modifies_arguments(f: NativeFunction) -> bool: + return f.func.kind() in [SchemaKind.inplace, SchemaKind.out] + + +@with_native_function_with_differentiability_info +def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> list[str]: + f = fn.func + inplace_view_body: list[str] = [] + + dispatcher_sig = DispatcherSignature.from_schema(f.func) + dispatcher_exprs = dispatcher_sig.exprs() + + # code-generated ADInplaceOrView kernels plumb and recompute dispatch keys directly through the kernel for performance. + # See Note [Plumbing Keys Through The Dispatcher] for details. + dispatch_key_set = "ks & c10::after_ADInplaceOrView_keyset" + redispatch_args = ", ".join([dispatch_key_set] + [a.expr for a in dispatcher_exprs]) + + # Note that this calls the slow, dispatching variants of manual_cpp_binding ops. + # We could probably work harder to ensure that the fast variants are called instead, but the perf benefit would be minimal. + if modifies_arguments(f): # inplace op + inplace_view_body.append( + INPLACE_REDISPATCH.substitute( + unambiguous_name=f.func.name.unambiguous_name(), + unpacked_args=redispatch_args, + ) + ) + for r in cpp.return_names(f): + inplace_view_body.append(f"increment_version({r});") + else: + assert get_view_info(f) is not None + inplace_view_body.append( + VIEW_REDISPATCH.substitute( + assign_return_values="auto " + TMP_VAR + " = ", + unambiguous_name=f.func.name.unambiguous_name(), + unpacked_args=redispatch_args, + ) + ) + call, rhs_value = emit_view_body(fn, TMP_VAR) + inplace_view_body.append(call) + assert rhs_value is not None + inplace_view_body.append( + ASSIGN_RETURN_VALUE.substitute( + return_values=tie_return_values(f), rhs_value=rhs_value + ) + ) + if f.func.returns: + inplace_view_body.append(f"return {get_return_value(f)};") + return inplace_view_body + + +@with_native_function +def gen_formals(f: NativeFunction) -> str: + return ", ".join( + # code-generated autograd kernels plumb and recompute dispatch keys directly through the kernel for performance. + # See Note [Plumbing Keys Through The Dispatcher] for details. + ["c10::DispatchKeySet ks"] + + [ + f"{cpp.argument_type(a, binds='__placeholder__', symint=True).cpp_type()} {a.name}" + for a in f.func.schema_order_arguments() + ] + ) + + +@with_native_function_with_differentiability_info +def inplace_or_view_method_definition( + fn: NativeFunctionWithDifferentiabilityInfo, +) -> str | None: + f = fn.func + if get_view_info(f) is None and ( + # For functions that modify their inputs but don't return them, + # we can't give them autograd support. + # See https://github.com/pytorch/pytorch/issues/53796 + not modifies_arguments(f) or len(f.func.returns) == 0 + ): + return None + return METHOD_DEFINITION.substitute( + return_type=cpp.returns_type(f.func.returns, symint=True).cpp_type(), + type_wrapper_name=type_wrapper_name(f), + formals=gen_formals(f), + type_definition_body=emit_inplace_or_view_body(fn), + ) + + +@with_native_function_with_differentiability_info +def inplace_or_view_method_registration( + fn: NativeFunctionWithDifferentiabilityInfo, +) -> str | None: + f = fn.func + if get_view_info(f) is None and ( + not modifies_arguments(f) or len(f.func.returns) == 0 + ): + return None + return WRAPPER_REGISTRATION.substitute( + unqual_operator_name_with_overload=f.func.name, + type_wrapper_name=type_wrapper_name(f), + class_type="ADInplaceOrView", + ) + + +def use_derived(fn: NativeFunctionWithDifferentiabilityInfo) -> bool: + f = fn.func + name = cpp.name(f.func) + return name not in MANUAL_AUTOGRAD and dispatch_strategy(fn) == "use_derived" + + +def gen_inplace_or_view_type_env( + fn: NativeFunctionWithDifferentiabilityInfo, +) -> dict[str, list[str]]: + definition = inplace_or_view_method_definition(fn) + registration = inplace_or_view_method_registration(fn) + + return { + "ops_headers": ( + [f"#include "] + if definition is not None + else [] + ), + "inplace_or_view_method_definitions": [definition] + if definition is not None + else [], + "inplace_or_view_wrapper_registrations": [registration] + if registration is not None + else [], + } + + +def gen_inplace_or_view_type( + out: str, + native_yaml_path: str, + tags_yaml_path: str, + fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo], + template_path: str, +) -> None: + # NOTE: see Note [Sharded File] at the top of the VariableType.cpp + # template regarding sharding of the generated files. + + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + fm.write_sharded( + "ADInplaceOrViewType.cpp", + [fn for fn in fns_with_infos if use_derived(fn)], + key_fn=lambda fn: fn.func.root_name, + base_env={ + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/ADInplaceOrViewType.cpp", + }, + env_callable=gen_inplace_or_view_type_env, + num_shards=2, + sharded_keys={ + "ops_headers", + "inplace_or_view_method_definitions", + "inplace_or_view_wrapper_registrations", + }, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_python_functions.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_python_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..af25d55ef38d87fc0d9398437f116f234634932d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_python_functions.py @@ -0,0 +1,1405 @@ +# Generates Python bindings for ATen functions +# +# The bindings are generated as methods on python_variable or functions on the +# torch._C._nn. torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._sparse +# or torch._C._special objects. +# + +# Code tries to stick to the following rules: +# +# - templates should be colocated with the functions that use them. +# no templates are currently shared between functions, but if that +# happens, maybe put the template with the first one +# +# - don't use environment dictionaries when calling template.substitute(). +# pass named arguments directly for everything, otherwise it's much too +# hard to track what's actually being used and by who +# +# - colocate any new hacks/adjustments with existing ones of the same kind. +# ideally in a data structure rather than code if possible. See e.g. +# SCHEMA_DEFAULT_CONVERSION_HACKS, etc. +# +# - similarly, conversions from one format to another should ideally happen +# all at once in a single place. +# +# - no nontrivial nested functions. couple-liners are ok but please no more. +# especially avoid functions that read/write outer variables defined far away. +# +# - raise RuntimeError instead of asserting, and put as much +# information as is available into the message. I.e. no need to +# plumb in new params whose only purpose is to fill out an error +# message, but use what's there +# + +from __future__ import annotations + +import itertools +import re +from collections import defaultdict +from typing import TYPE_CHECKING + +import yaml + +from torchgen.api import cpp +from torchgen.api.python import ( + arg_parser_output_exprs, + cpp_dispatch_exprs, + cpp_dispatch_target, + dispatch_lambda_args, + dispatch_lambda_exprs, + dispatch_lambda_return_str, + has_tensor_options, + PythonSignature, + PythonSignatureDeprecated, + PythonSignatureGroup, + PythonSignatureNativeFunctionPair, + signature, + signature_from_schema, + structseq_fieldnames, +) +from torchgen.code_template import CodeTemplate +from torchgen.context import with_native_function +from torchgen.gen import cpp_string, parse_native_yaml, parse_tags_yaml +from torchgen.model import ( + Argument, + BaseOperatorName, + FunctionSchema, + NativeFunction, + SchemaKind, + Type, + Variant, +) +from torchgen.utils import FileManager, split_name_params +from torchgen.yaml_utils import YamlLoader + +from .gen_inplace_or_view_type import is_tensor_list_type +from .gen_trace_type import should_trace + + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable, Sequence + + +# +# declarations blocklist +# We skip codegen for these functions, for various reasons. +# Future PRs will categorize this list and eliminate or hoist +# them out of eager-only codegen. +# See https://github.com/pytorch/pytorch/issues/30788 +# + +# These functions require manual Python bindings or are not exposed to Python +_SKIP_PYTHON_BINDINGS = [ + "alias", + "contiguous", + "is_cuda", + "is_sparse", + "is_sparse_csr", + "size", + "stride", + "sym_is_contiguous", + "sym_size", + "sym_stride", + "sym_storage_offset", + "sym_numel", + ".*_backward", + ".*_backward_(out|input|weight|bias)", + ".*_forward", + ".*_forward_out", + ".*_jvp", + "_unsafe_view", + "tensor", + "_?sparse_(coo|compressed|csr|csc|bsr|bsc)_tensor.*", + "_range.*", + "_sparse_add_out", + "_sparse_div.*", + "_sparse_mul.*", + "_sparse_sub.*", + "_sparse_dense_add_out", + "index", + "index_out", + "unique_dim_consecutive", + "_cumsum.*", + "_cumprod.*", + "_sum.*", + "_prod.*", + "_th_.*", + "_thnn_.*", + "range.*", + "_solve.*", + "_inverse.*", + "_cholesky.*", + "_triangular_solve.*", + "_qr.*", + "_svd.*", + "slice", + "item", + "_local_scalar_dense", + "to", + "_to_copy", + "_to_copy_out", + "_reshape_copy", + "_reshape_copy_out", + "copy_sparse_to_sparse_", + "copy_", + "_foreach_copy", + "numpy_T", + "matrix_H", + "mT", + "mH", # these need to be an attributes in Python, not functions + "nonzero(_(out|numpy))?", + "set_data", + ".*_overrideable", # overridable functions for backend extension + "data", + "is_leaf", + "output_nr", + "_version", + "requires_grad_", + "retains_grad", + "set_", + "_fw_primal", + "fake_quantize_per_tensor_affine_cachemask", + "fake_quantize_per_channel_affine_cachemask", + "_new_zeros_with_same_feature_meta", + "_has_same_storage_numel", # used for forward AD internals + "_reshape_alias", + "replace_", # only used by the functionalization pass, doesn't need to be exposed to python + "copy", # only used by the functionalization pass + "fill.Tensor", # only used by the functionalization pass + "fill.Scalar", # only used by the functionalization pass + "lift.*", + "normal_functional", # only used by the functionalization pass + "nbytes", + "itemsize", + "_batch_norm_with_update", + "_batch_norm_with_update_out", + "_batch_norm_no_update", +] + +SKIP_PYTHON_BINDINGS = [ + re.compile(rf"^{pattern}$") for pattern in _SKIP_PYTHON_BINDINGS +] + +# These function signatures are not exposed to Python. Note that this signature +# list does not support regex. +SKIP_PYTHON_BINDINGS_SIGNATURES = [ + "add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", + "add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)", + "sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", + "sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)", + "mul.Scalar(Tensor self, Scalar other) -> Tensor", + "mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", + "div.Scalar(Tensor self, Scalar other) -> Tensor", + "div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", +] + + +@with_native_function +def should_generate_py_binding(f: NativeFunction) -> bool: + # NativeFunctions that are entirely code-generated should not get python bindings + # because these codegen implementations are often inefficient. A handful of + # view_copy style ops were exposed accidentally when they were handwritten and now + # that we are moving them to codegen for bc reasons we need to keep them exposed in + # python. + if "generated" in f.tags and "view_copy" not in f.tags: + return False + + name = cpp.name(f.func) + for skip_regex in SKIP_PYTHON_BINDINGS: + if skip_regex.match(name): + return False + + signature = str(f.func) + for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES: + if pattern == signature: + return False + return True + + +def get_pycname(name: BaseOperatorName) -> str: + return f"THPVariable_{name}" + + +def is_noarg(overloads: Sequence[PythonSignatureNativeFunctionPair]) -> bool: + return len(overloads) == 1 and overloads[0].signature.arguments_count() == 0 + + +def is_py_variable_method(f: NativeFunction) -> bool: + return f.python_module is None and Variant.method in f.variants + + +def is_py_torch_function(f: NativeFunction) -> bool: + return f.python_module is None and Variant.function in f.variants + + +def is_py_nn_function(f: NativeFunction) -> bool: + return f.python_module == "nn" + + +def is_py_fft_function(f: NativeFunction) -> bool: + return f.python_module == "fft" + + +def is_py_linalg_function(f: NativeFunction) -> bool: + return f.python_module == "linalg" + + +def is_py_nested_function(f: NativeFunction) -> bool: + return f.python_module == "nested" + + +def is_py_sparse_function(f: NativeFunction) -> bool: + return f.python_module == "sparse" + + +def is_py_special_function(f: NativeFunction) -> bool: + return f.python_module == "special" + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Main Function +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def gen( + out: str, + native_yaml_path: str, + tags_yaml_path: str, + deprecated_yaml_path: str, + template_path: str, + *, + symint: bool = True, +) -> None: + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + native_functions = parse_native_yaml( + native_yaml_path, tags_yaml_path + ).native_functions + native_functions = list(filter(should_generate_py_binding, native_functions)) + + methods = load_signatures(native_functions, deprecated_yaml_path, method=True) + create_python_bindings( + fm, + methods, + is_py_variable_method, + None, + "python_variable_methods.cpp", + method=True, + symint=symint, + ) + + # NOTE: num_shards here must be synced with gatherTorchFunctions in + # torch/csrc/autograd/python_torch_functions_manual.cpp + functions = load_signatures(native_functions, deprecated_yaml_path, method=False) + create_python_bindings_sharded( + fm, + functions, + is_py_torch_function, + "torch", + "python_torch_functions.cpp", + method=False, + num_shards=3, + symint=symint, + ) + + create_python_bindings( + fm, + functions, + is_py_nn_function, + "torch.nn", + "python_nn_functions.cpp", + method=False, + symint=symint, + ) + + create_python_bindings( + fm, + functions, + is_py_fft_function, + "torch.fft", + "python_fft_functions.cpp", + method=False, + symint=symint, + ) + + create_python_bindings( + fm, + functions, + is_py_linalg_function, + "torch.linalg", + "python_linalg_functions.cpp", + method=False, + symint=symint, + ) + + create_python_bindings( + fm, + functions, + is_py_nested_function, + "torch.nested", + "python_nested_functions.cpp", + method=False, + ) + + create_python_bindings( + fm, + functions, + is_py_sparse_function, + "torch.sparse", + "python_sparse_functions.cpp", + method=False, + symint=symint, + ) + + create_python_bindings( + fm, + functions, + is_py_special_function, + "torch.special", + "python_special_functions.cpp", + method=False, + symint=symint, + ) + + # Currently, we only use `functions` to generate `return_types` bindings. + # All methods which return structseq have function variant at this point. + # If any method only operator with structseq is added in the future, + # we will have to address that. + create_python_return_type_bindings( + fm, functions, lambda fn: True, "python_return_types.cpp" + ) + create_python_return_type_bindings_header( + fm, functions, lambda fn: True, "python_return_types.h" + ) + + valid_tags = parse_tags_yaml(tags_yaml_path) + + def gen_tags_enum() -> dict[str, str]: + return { + "enum_of_valid_tags": ( + "".join( + [f'\n.value("{tag}", at::Tag::{tag})' for tag in sorted(valid_tags)] + ) + ) + } + + fm.write("python_enum_tag.cpp", gen_tags_enum) + + +def group_filter_overloads( + pairs: Sequence[PythonSignatureNativeFunctionPair], + pred: Callable[[NativeFunction], bool], +) -> dict[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]: + grouped: dict[BaseOperatorName, list[PythonSignatureNativeFunctionPair]] = ( + defaultdict(list) + ) + for pair in pairs: + if pred(pair.function): + grouped[pair.function.func.name.name].append(pair) + return grouped + + +def create_python_bindings( + fm: FileManager, + pairs: Sequence[PythonSignatureNativeFunctionPair], + pred: Callable[[NativeFunction], bool], + module: str | None, + filename: str, + *, + method: bool, + symint: bool = True, +) -> None: + """Generates Python bindings to ATen functions""" + py_methods: list[str] = [] + ops_headers: list[str] = [] + py_method_defs: list[str] = [] + py_forwards: list[str] = [] + + grouped = group_filter_overloads(pairs, pred) + + for name in sorted(grouped.keys(), key=str): + overloads = grouped[name] + py_methods.append( + method_impl(name, module, overloads, method=method, symint=symint) + ) + py_method_defs.append(method_def(name, module, overloads, method=method)) + py_forwards.extend(forward_decls(name, overloads, method=method)) + ops_headers.append(f"#include ") + + fm.write_with_template( + filename, + filename, + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/{filename}", + "ops_headers": ops_headers, + "py_forwards": py_forwards, + "py_methods": py_methods, + "py_method_defs": py_method_defs, + }, + ) + + +def create_python_return_type_bindings( + fm: FileManager, + pairs: Sequence[PythonSignatureNativeFunctionPair], + pred: Callable[[NativeFunction], bool], + filename: str, +) -> None: + """ + Generate function to initialize and return named tuple for native functions + which returns named tuple and registration invocations in `python_return_types.cpp`. + """ + py_return_types_definition: list[str] = [] + py_return_types_registrations: list[str] = [] + + grouped = group_filter_overloads(pairs, pred) + + for name in sorted(grouped.keys(), key=str): + overloads = grouped[name] + definitions, registrations = generate_return_type_definition_and_registrations( + overloads + ) + py_return_types_definition.append( + "" if not definitions else "\n".join(definitions) + ) + py_return_types_registrations.append( + "" if not registrations else "\n".join(registrations) + ) + + fm.write_with_template( + filename, + filename, + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/{filename}", + "py_return_types": py_return_types_definition, + "py_return_types_registrations": py_return_types_registrations, + }, + ) + + +def create_python_return_type_bindings_header( + fm: FileManager, + pairs: Sequence[PythonSignatureNativeFunctionPair], + pred: Callable[[NativeFunction], bool], + filename: str, +) -> None: + """ + Generate function to initialize and return named tuple for native functions + which returns named tuple and relevant entry for the map in `python_return_types.cpp`. + """ + py_return_types_declarations: list[str] = [] + + grouped = group_filter_overloads(pairs, pred) + + for name in sorted(grouped.keys(), key=str): + overloads = grouped[name] + declarations = generate_return_type_declarations(overloads) + py_return_types_declarations.append( + "" if not declarations else "\n".join(declarations) + ) + + fm.write_with_template( + filename, + filename, + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/{filename}", + "py_return_types_declarations": py_return_types_declarations, + }, + ) + + +def create_python_bindings_sharded( + fm: FileManager, + pairs: Sequence[PythonSignatureNativeFunctionPair], + pred: Callable[[NativeFunction], bool], + module: str | None, + filename: str, + *, + method: bool, + num_shards: int, + symint: bool = True, +) -> None: + """Generates Python bindings to ATen functions""" + grouped = group_filter_overloads(pairs, pred) + + def key_func( + kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]], + ) -> str: + return kv[0].base + + def env_func( + kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]], + ) -> dict[str, list[str]]: + name, fn_pairs = kv + return { + "ops_headers": [f"#include "], + "py_forwards": list(forward_decls(name, fn_pairs, method=method)), + "py_methods": [ + method_impl(name, module, fn_pairs, method=method, symint=symint) + ], + "py_method_defs": [method_def(name, module, fn_pairs, method=method)], + } + + fm.write_sharded( + filename, + grouped.items(), + base_env={ + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/{filename}", + }, + key_fn=key_func, + env_callable=env_func, + num_shards=num_shards, + sharded_keys={"ops_headers", "py_forwards", "py_methods", "py_method_defs"}, + ) + + +def load_signatures( + native_functions: list[NativeFunction], + deprecated_yaml_path: str, + *, + method: bool, + skip_deprecated: bool = False, + pyi: bool = False, +) -> Sequence[PythonSignatureNativeFunctionPair]: + @with_native_function + def gen_signature_pairs(f: NativeFunction) -> PythonSignatureNativeFunctionPair: + return PythonSignatureNativeFunctionPair( + signature=signature(f, method=method, pyi=pyi), + function=f, + ) + + pairs = list(map(gen_signature_pairs, native_functions)) + deprecated = load_deprecated_signatures( + pairs, deprecated_yaml_path, method=method, pyi=pyi + ) + return pairs if skip_deprecated else pairs + deprecated + + +def load_deprecated_signatures( + pairs: Sequence[PythonSignatureNativeFunctionPair], + deprecated_yaml_path: str, + *, + method: bool, + pyi: bool, +) -> list[PythonSignatureNativeFunctionPair]: + # The deprecated.yaml doesn't have complete type information, we need + # find and leverage the original ATen signature (to which it delegates + # the call) to generate the full python signature. + # We join the deprecated and the original signatures using type-only form. + + # group the original ATen signatures by name + grouped: dict[str, list[PythonSignatureNativeFunctionPair]] = defaultdict(list) + for pair in pairs: + grouped[pair.signature.name].append(pair) + + # find matching original signatures for each deprecated signature + results: list[PythonSignatureNativeFunctionPair] = [] + + with open(deprecated_yaml_path) as f: + deprecated_defs = yaml.load(f, Loader=YamlLoader) + + for deprecated in deprecated_defs: + schema = FunctionSchema.parse(deprecated["name"]) + aten_name, call_args = split_name_params(deprecated["aten"]) + is_out = aten_name.endswith("_out") + if is_out: + aten_name = aten_name.replace("_out", "") + + # HACK: these are fixed constants used to pass the aten function. + # The type must be known ahead of time + known_constants = { + "1": Type.parse("Scalar"), + } + schema_args_by_name = {a.name: a for a in schema.arguments.flat_all} + for name in call_args: + assert name in schema_args_by_name or name in known_constants, ( + f"deprecation definition: Unrecognized value {name}" + ) + + # Map deprecated signature arguments to their aten signature and test + # if the types and alias annotation match. + def is_schema_compatible( + aten_schema: FunctionSchema, + ) -> bool: + arguments: Iterable[Argument] + if is_out: + arguments = itertools.chain( + aten_schema.arguments.out, aten_schema.arguments.flat_non_out + ) + else: + arguments = aten_schema.arguments.flat_all + + for i, arg in enumerate(arguments): + if i < len(call_args): + arg_name = call_args[i] + if arg_name in known_constants: + schema_type = known_constants[arg_name] + schema_annotation = None + else: + schema_arg = schema_args_by_name[arg_name] + schema_type = schema_arg.type + schema_annotation = schema_arg.annotation + + if schema_type != arg.type or schema_annotation != arg.annotation: + return False + else: + if arg.default is None: + return False + + return len(schema.returns) == len(aten_schema.returns) and all( + a == b for a, b in zip(schema.returns, aten_schema.returns) + ) + + any_schema_found = False + for pair in grouped[aten_name]: + if not is_schema_compatible(pair.function.func): + continue + any_schema_found = True + + python_sig = signature_from_schema( + schema, + category_override=pair.function.category_override, + method=method, + pyi=pyi, + ) + + results.append( + PythonSignatureNativeFunctionPair( + signature=PythonSignatureDeprecated( + name=python_sig.name, + input_args=python_sig.input_args, + input_kwargs=python_sig.input_kwargs, + output_args=python_sig.output_args, + tensor_options_args=python_sig.tensor_options_args, + method=python_sig.method, + deprecated_schema=schema, + deprecated_args_exprs=tuple(call_args), + returns=python_sig.returns, + ), + function=pair.function, + ) + ) + assert any_schema_found, ( + f"No native function with name {aten_name} matched signature:\n {str(schema)}" + ) + + return results + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Named Tuple Codegen +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +@with_native_function +def gen_structseq_typename_key(f: NativeFunction) -> str: + name = cpp.name(f.func) + fieldnames = structseq_fieldnames(f.func.returns) + return "_".join([name] + fieldnames) + + +def emit_structseq_call( + overloads: Sequence[PythonSignatureNativeFunctionPair], +) -> tuple[list[str], dict[str, str]]: + """ + Generate block of named tuple type def inits, and add typeref snippets + to declarations that use them + """ + typenames: dict[ + str, str + ] = {} # map from unique name + field name lists to typedef name + typedefs: list[str] = [] # typedef declarations and init code + + for overload in overloads: + fieldnames = structseq_fieldnames(overload.function.func.returns) + if not fieldnames: + continue + + name = cpp.name(overload.function.func) # use @with_native_function? + tn_key = gen_structseq_typename_key(overload.function) + typename = typenames.get(tn_key) + if typename is None: + typename = f"NamedTuple{'' if not typedefs else len(typedefs)}" + typenames[tn_key] = typename + typedefs.append( + f"""\ +static PyTypeObject* {typename} = generated::get_{name}_structseq();""" + ) + + return typedefs, typenames + + +def generate_return_type_definition_and_registrations( + overloads: Sequence[PythonSignatureNativeFunctionPair], +) -> tuple[list[str], list[str]]: + """ + Generate block of function in `python_return_types.cpp` to initialize + and return named tuple for a native function which returns named tuple + and registration invocations in same file. + """ + typenames: dict[ + str, str + ] = {} # map from unique name + field name lists to typedef name + definitions: list[str] = [] # function definition to register the typedef + registrations: list[str] = [] # register call for the typedef + + for overload in overloads: + fieldnames = structseq_fieldnames(overload.function.func.returns) + if not fieldnames: + continue + + fields = ", ".join(f'{{"{fn}", ""}}' for fn in fieldnames) + + name = cpp.name(overload.function.func) # use @with_native_function? + tn_key = gen_structseq_typename_key(overload.function) + typename = typenames.get(tn_key) + + if typename is None: + typename = f"{name}NamedTuple{'' if not definitions else len(definitions)}" + typenames[tn_key] = typename + definitions.append( + f"""\ +PyTypeObject* get_{name}_structseq() {{ + static PyStructSequence_Field NamedTuple_fields[] = {{ {fields}, {{nullptr}} }}; + static PyTypeObject {typename}; + static bool is_initialized = false; + static PyStructSequence_Desc desc = {{ "torch.return_types.{name}", nullptr, NamedTuple_fields, {len(fieldnames)} }}; + if (!is_initialized) {{ + PyStructSequence_InitType(&{typename}, &desc); + {typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr; + is_initialized = true; + }} + return &{typename}; +}} +""" + ) + registrations.append( + f'addReturnType(return_types_module, "{name}", generated::get_{name}_structseq());' + ) + + return definitions, registrations + + +def generate_return_type_declarations( + overloads: Sequence[PythonSignatureNativeFunctionPair], +) -> list[str]: + """ + Generate block of function declarations in `python_return_types.h` to initialize + and return named tuple for a native function. + """ + typenames: dict[ + str, str + ] = {} # map from unique name + field name lists to typedef name + declarations: list[str] = [] # function declaration to register the typedef + + for overload in overloads: + fieldnames = structseq_fieldnames(overload.function.func.returns) + if not fieldnames: + continue + + name = cpp.name(overload.function.func) # use @with_native_function? + tn_key = gen_structseq_typename_key(overload.function) + typename = typenames.get(tn_key) + + if typename is None: + typename = ( + f"{name}NamedTuple{'' if not declarations else len(declarations)}" + ) + typenames[tn_key] = typename + declarations.append(f"PyTypeObject* get_{name}_structseq();") + + return declarations + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Method Impl Codegen +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + +# python binding for all overloads of a particular function/method +PY_VARIABLE_METHOD_VARARGS = CodeTemplate( + r"""\ +// ${name} +static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs) +{ + ${method_header} + static PythonArgParser parser({ + ${signatures} + }, /*traceable=*/${traceable}); + + ParsedArgs<${max_args}> parsed_args; + auto _r = parser.parse(${self_}, args, kwargs, parsed_args); + ${check_has_torch_function} + switch (_r.idx) { + ${dispatch} + } + ${method_footer} +} + +""" +) + +# handler for a single parsed signature - may be a single overload or +# a pair of overloads that whose signatures only differ in output params +# (plugged into PY_VARIABLE_METHOD_VARARGS as an item in ${dispatch}) +PY_VARIABLE_CASE = CodeTemplate( + """\ +case ${overload_index}: { + ${body} +} +""" +) + +# python binding for single-overload function/method +PY_VARIABLE_METHOD_VARARGS_SINGLETON = CodeTemplate( + """\ +// ${name} +static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs) +{ + ${method_header} + static PythonArgParser parser({ + ${signatures} + }, /*traceable=*/${traceable}); + + ParsedArgs<${max_args}> parsed_args; + auto _r = parser.parse(${self_}, args, kwargs, parsed_args); + ${check_has_torch_function} + ${dispatch} + ${method_footer} +} + +""" +) + +# python binding for a method with no args, shortcuts parsing +PY_VARIABLE_METHOD_NOARGS = CodeTemplate( + """\ +// ${name} +static PyObject * ${pycname}(PyObject* self_, PyObject* args) +{ + ${method_header} + ${check_has_torch_function} + ${dispatch} + ${method_footer} +} + +""" +) + + +def method_impl( + name: BaseOperatorName, + module: str | None, + overloads: Sequence[PythonSignatureNativeFunctionPair], + *, + method: bool, + symint: bool = True, +) -> str: + """ + Generate a python binding for all overloads of an op. + """ + pycname = get_pycname(name) + noarg = is_noarg(overloads) + structseq_inits, structseq_typenames = emit_structseq_call(overloads) + + method_header = ["HANDLE_TH_ERRORS"] + method_header += structseq_inits + method_header += ( + ["const Tensor& self = THPVariable_Unpack(self_);"] if method else [] + ) + + method_footer = ([] if noarg else ["Py_RETURN_NONE;"]) + ["END_HANDLE_TH_ERRORS"] + + traceable = "true" if all(should_trace(o.function) for o in overloads) else "false" + + grouped_overloads: Sequence[PythonSignatureGroup] = group_overloads( + overloads, symint=symint + ) + is_singleton = len(grouped_overloads) == 1 + signatures: list[str] = [] + dispatch: list[str] = [] + for overload_index, overload in enumerate(grouped_overloads): + signature = overload.signature.signature_str(symint=symint) + signatures.append(f"{cpp_string(str(signature))},") + dispatch_body = emit_dispatch_case(overload, structseq_typenames, symint=symint) + dispatch.append( + PY_VARIABLE_CASE.substitute( + overload_index=overload_index, body=dispatch_body + ) + if not is_singleton + else dispatch_body + ) + + if noarg: + template = PY_VARIABLE_METHOD_NOARGS + elif is_singleton: + template = PY_VARIABLE_METHOD_VARARGS_SINGLETON + else: + template = PY_VARIABLE_METHOD_VARARGS + + return template.substitute( + name=name, + pycname=pycname, + method_header=method_header, + max_args=max(o.signature.arguments_count() for o in overloads), + signatures=signatures, + traceable=traceable, + check_has_torch_function=gen_has_torch_function_check( + name=name, + module=module, + noarg=noarg, + method=method, + ), + dispatch=dispatch, + method_footer=method_footer, + self_="self_" if method else "nullptr", + ) + + +def gen_has_torch_function_check( + name: BaseOperatorName, module: str | None, *, noarg: bool, method: bool +) -> str: + if noarg: + if method: + return f"""\ +if(check_has_torch_function(self_)) {{ + return handle_torch_function(self_, "{name}"); +}} +""" + else: + return "" + + self_ = "self_" if method else "nullptr" + namespace = ( + { + "torch": "THPVariableFunctionsModule", + "torch.nn": "THPNNVariableFunctionsModule", + "torch.fft": "THPFFTVariableFunctionsModule", + "torch.linalg": "THPLinalgVariableFunctionsModule", + "torch.nested": "THPNestedVariableFunctionsModule", + "torch.sparse": "THPSparseVariableFunctionsModule", + "torch.special": "THPSpecialVariableFunctionsModule", + }[module] + if module + else "THPVariableClass" + ) + + return f"""\ +if(_r.has_torch_function()) {{ + return handle_torch_function(_r, {self_}, args, kwargs, {namespace}, "{module or "torch.Tensor"}"); +}} +""" + + +# handler for output/no-output overload pair +PY_VARIABLE_OUT = CodeTemplate( + """\ +if (_r.isNone(${out_idx})) { + ${call_dispatch} +} else { + ${call_dispatch_out} +} +""" +) + + +def emit_dispatch_case( + overload: PythonSignatureGroup, + structseq_typenames: dict[str, str], + *, + symint: bool = True, +) -> str: + """ + Emit dispatch code for a single parsed signature. This corresponds to either + a single native function, or a pair that differ only in output params. In the + latter case, a single python signature is used for both and dispatching + switches on the presence/absence of passed output args. + """ + if overload.outplace is not None: + # dispatch output and no-output variants, branch on _r.isNone() + return PY_VARIABLE_OUT.substitute( + out_idx=overload.signature.output_idx(), + call_dispatch=emit_single_dispatch( + overload.signature, overload.base, structseq_typenames, symint=symint + ), + call_dispatch_out=emit_single_dispatch( + overload.signature, + overload.outplace, + structseq_typenames, + symint=symint, + ), + ) + else: + # no-output version only + return emit_single_dispatch( + overload.signature, overload.base, structseq_typenames, symint=symint + ) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Forward Declarations Codegen +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def forward_decls( + name: BaseOperatorName, + overloads: Sequence[PythonSignatureNativeFunctionPair], + *, + method: bool, +) -> tuple[str, ...]: + if method: + return () + + pycname = get_pycname(name) + if is_noarg(overloads): + return ( + f"""\ +static PyObject * {pycname}(PyObject* self_, PyObject* args); +""", + ) + else: + return ( + f"""\ +static PyObject * {pycname}(PyObject* self_, PyObject* args, PyObject* kwargs); +""", + ) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Method Def (Binding Table Entry) Codegen +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def method_def( + name: BaseOperatorName, + module: str | None, + overloads: Sequence[PythonSignatureNativeFunctionPair], + *, + method: bool, +) -> str: + """ + Generate method def entry. + """ + pycname = get_pycname(name) + + if name.dunder_method: + # PyMethodDef entry for binary op, throws not implemented error + pycname = f"TypeError_to_NotImplemented_<{pycname}>" + + if is_noarg(overloads): + flags = "METH_NOARGS" if method else "METH_VARARGS | METH_KEYWORDS" + else: + pycname = f"castPyCFunctionWithKeywords({pycname})" + flags = "METH_VARARGS | METH_KEYWORDS" + + if module == "torch": + flags += " | METH_STATIC" + + return f'{{"{name}", {pycname}, {flags}, nullptr}},' + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Overload Sorting and Grouping +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def group_overloads( + overloads: Sequence[PythonSignatureNativeFunctionPair], *, symint: bool = True +) -> Sequence[PythonSignatureGroup]: + bases: dict[str, PythonSignatureNativeFunctionPair] = {} + outplaces: dict[str, PythonSignatureNativeFunctionPair] = {} + + # first group by signature ignoring out arguments + for overload in overloads: + sig = overload.signature.signature_str(skip_outputs=True, symint=symint) + if overload.function.func.is_out_fn(): + if sig in outplaces: + raise RuntimeError( + f"Found duplicated function definition:\n- {overload.function.func}.\n" + f"Existing definition:\n- {outplaces[sig].function.func}." + ) + outplaces[sig] = overload + else: + if sig in bases: + raise RuntimeError( + f"Found duplicated function definition:\n- {overload.function.func}.\n" + f"Existing definition:\n- {bases[sig].function.func}." + ) + bases[sig] = overload + + for sig, out in outplaces.items(): + if sig not in bases: + candidates: list[str] = [] + for overload in overloads: + if ( + str(overload.function.func.name.name) + == str(out.function.func.name.name) + and not overload.function.func.is_out_fn() + and not overload.signature.deprecated + ): + candidates.append( + overload.signature.signature_str( + skip_outputs=True, symint=symint + ) + ) + out_sig = out.signature.signature_str(symint=symint) + raise RuntimeError( + f"While identifying overloads, we found an out schema {out_sig} without a corresponding non-out variant. " + f"We expected the non-out variant to have schema: \n- {sig}\nPlease check that you spelled the schema " + "correctly in native_functions.yaml. We discovered the following candidate(s): \n" + + "\n".join(f"- {candidate}" for candidate in candidates) + ) + + grouped = [ + PythonSignatureGroup.from_pairs( + functional=base, + out=outplaces.get(sig), + ) + for sig, base in bases.items() + ] + return sort_overloads(grouped, symint=symint) + + +# This function declares a partial order on declarations, and sorts them according +# to its linear extension. This is necessary, because there's some ambiguity in the +# choice of overload, and we want a different order. +# +# See Note[Order of overloads matters] +# +# A few examples of ambiguous python signature pairs. +# +# All parameters have the same type, except one taking Tensor the other taking +# Scalar. A numeric PyObject can be casted into Tensor, and a zero-dim Tensor +# object can be accepted as Scalar type parameter (see python_arg_parser.cpp). +# Therefore, same input arguments might be accepted by either python signature. +# We want to always parse the one taking Tensor first. +# +# bitwise_and(Tensor input, Tensor other, *, Tensor out=None) +# bitwise_and(Tensor input, Scalar other, *, Tensor out=None) +# +# If they have different number of parameters then they are not ambiguous - but +# the difference on output param can be ignored as it's optional. +# +# multiply(Tensor input, Tensor other, *, Tensor out=None) +# multiply(Tensor input, Scalar other) +# +# Both positional args and keyword-only args are considered together. +# +# subtract(Tensor other, *, Scalar alpha=1) +# subtract(Scalar other, Scalar alpha=1) +# +# A few ambiguous cases which it does NOT handle yet. +# +# If there is any difference in other parameters besides the Tensor/Scalar +# difference, then they are not considered ambiguous by this method anymore. +# However, the difference could be too trivial to disambiguate. +# +# foo(Tensor input, Scalar other, Scalar bar) +# foo(Tensor input, Tensor other, double bar) +# +# If they are taking different number of parameters then they are not considered +# ambiguous anymore, even if the difference is only on optional kwargs. +# +# foo(Scalar other, Scalar alpha=1) +# foo(Tensor other, *, Scalar alpha=1, Scalar beta=1) +# + + +def sort_overloads( + grouped_overloads: Sequence[PythonSignatureGroup], *, symint: bool = True +) -> Sequence[PythonSignatureGroup]: + # NB: Smaller here means lower priority + + def is_arg_smaller(t1: Type, t2: Type) -> bool: + return ( + str(t1) == "Scalar" + and str(t2) == "Tensor" + or str(t1) == "Scalar?" + and str(t2) == "Tensor?" + or "Dimname" in str(t1) + and "Dimname" not in str(t2) + or + # In the discussion https://github.com/pytorch/pytorch/issues/54555 it has been + # discussed why it is important to prioritize int/int? over int[] + str(t1) == "int[]" + and (str(t2) == "int" or str(t2) == "int?") + or + # TensorList currently throws an error during argument parsing, that's why it needs to be + # last in signature ordering. See discussion: https://github.com/pytorch/pytorch/issues/58087 + str(t1) == "Tensor[]" + and str(t2).find("[]") != -1 + or + # Prioritize IntArrayRef overload over SymIntArrayRef + str(t1) == "SymInt[]" + and str(t2) == "int[]" + or + # Make sure both in, SymInt are sorted consistently w.r.t. Tensor since Tensor can be implicitly + # converted to either int or SymInt. Prioritize the Tensor overload since it otherwise gets shadowed. + (str(t1) == "SymInt" or str(t1) == "int") + and str(t2) == "Tensor" + ) + + def is_smaller(s1: PythonSignature, s2: PythonSignature) -> bool: + """Returns True if s1 < s2 in the partial order.""" + args1, args2 = s1.arguments(skip_outputs=True), s2.arguments(skip_outputs=True) + if len(args1) != len(args2): + return False + # TODO: should use some canonical form instead of 'str(arg.type)' - see comments + # above. The old codegen used the deprecated 'dynamic_type(arg.type)', which + # ignores the optional annotation, i.e. 'Scalar' and 'Scalar?'. + equal = all(arg1.type == arg2.type for arg1, arg2 in zip(args1, args2)) + smaller_or_equal = all( + str(arg1.type) == str(arg2.type) or is_arg_smaller(arg1.type, arg2.type) + for arg1, arg2 in zip(args1, args2) + ) + return smaller_or_equal and not equal + + # First sort by signature + grouped_overloads = sorted( + grouped_overloads, key=lambda x: x.signature.signature_str(symint=symint) + ) + + # Construct the relation graph + larger_than: dict[int, set[int]] = defaultdict(set) + for i1, overload1 in enumerate(grouped_overloads): + for i2, overload2 in enumerate(grouped_overloads): + if is_smaller(overload1.signature, overload2.signature): + larger_than[i1].add(i2) + + if not larger_than: + return list(grouped_overloads) + + # Use a topological sort to sort overloads according to the partial order. + N = len(grouped_overloads) + sorted_ids: list[int] = list(filter(lambda x: x not in larger_than, range(N))) + + for idx in range(N): + # The size of sorted_ids will grow to N eventually. + i = sorted_ids[idx] + for j in sorted(larger_than.keys()): + larger = larger_than[j] + larger.discard(i) + if not larger: + del larger_than[j] + sorted_ids.append(j) + + return [grouped_overloads[x] for x in sorted_ids] + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# +# Codegen API Integration +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # + + +def emit_single_dispatch( + ps: PythonSignature, + f: NativeFunction, + structseq_typenames: dict[str, str], + *, + symint: bool = True, +) -> str: + """ + Emit dispatch code for a single native function. + """ + + @with_native_function + def go(f: NativeFunction) -> str: + # header comments + if isinstance(ps, PythonSignatureDeprecated): + schema_comment = f"// [deprecated] aten::{ps.deprecated_schema}" + else: + schema_comment = f"// aten::{f.func}" + + # dispatch lambda signature + name = cpp.name(f.func) + lambda_formals = ", ".join( + f"{a.type_str} {a.name}" for a in dispatch_lambda_args(ps, f, symint=symint) + ) + lambda_return = dispatch_lambda_return_str(f) + + # dispatch lambda body + dispatch_callee = cpp_dispatch_target(f) + dispatch_args = ", ".join(cpp_dispatch_exprs(f, python_signature=ps)) + + # from arg parser outputs to dispatch lambda arguments + parser_outputs = arg_parser_output_exprs(ps, f, symint=symint) + lambda_arg_exprs = dispatch_lambda_exprs(ps, f, symint=symint) + inits = "\n".join(lambda_arg_exprs.inits) + lambda_args = ", ".join(lambda_arg_exprs.exprs) + + # scatter fields + # TODO: Checking `ps.method and ('requires_grad' in parser_outputs)` is a hacky + # solution for enabling the 'requires_grad' argument for tensor methods + # new_full, new_empty, and new_zeros. A much better but more difficult to + # implement solution involves refactoring according to Ed's description here: + # https://github.com/pytorch/pytorch/issues/36455#issuecomment-614767589 + need_set_requires_grad = ps.tensor_options_args and ( + not has_tensor_options(f) + or (ps.method and ("requires_grad" in parser_outputs)) + ) + set_requires_grad = ( + f".set_requires_grad({parser_outputs['requires_grad'].expr})" + if need_set_requires_grad + else "" + ) + + if lambda_return == "void": + # Make in-place foreach return `self` at python-binding level. + # ref: https://github.com/pytorch/pytorch/pull/118622#pullrequestreview-1904804954 + self_arg = f.func.arguments.self_arg + return_stmt: str + if ( + str(f.func.name).startswith("_foreach_") + and f.func.kind() == SchemaKind.inplace + ): + # note(crcrpar): `_foreach_pow.ScalarAndTensor` does NOT have its in-place + # variant and it unlikely to have it in the future. Thus it's safe to have the following assert. + assert self_arg is not None and is_tensor_list_type( + self_arg.argument.type + ) + return_stmt = """PyObject* self_tensorlist = _r.args[0]; +Py_INCREF(self_tensorlist); +return self_tensorlist; +""" + else: + return_stmt = "Py_RETURN_NONE;" + return f"""\ +{schema_comment} +{inits} +auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{ + pybind11::gil_scoped_release no_gil; + {dispatch_callee}({dispatch_args}); +}}; +dispatch_{name}({lambda_args}){set_requires_grad}; +{return_stmt} +""" + else: + typename = structseq_typenames.get(gen_structseq_typename_key(f)) + structseq_typeref = f"{typename}, " if typename is not None else "" + return f"""\ +{schema_comment} +{inits} +auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{ + pybind11::gil_scoped_release no_gil; + return {dispatch_callee}({dispatch_args}); +}}; +return wrap({structseq_typeref}dispatch_{name}({lambda_args}){set_requires_grad}); +""" + + return go(f) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_trace_type.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_trace_type.py new file mode 100644 index 0000000000000000000000000000000000000000..0a4ecbd14f514851610c27a4d810b88db934d4df --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_trace_type.py @@ -0,0 +1,540 @@ +from __future__ import annotations + +import itertools +from typing import TYPE_CHECKING + +from torchgen.api import cpp +from torchgen.api.types import DispatcherSignature +from torchgen.code_template import CodeTemplate +from torchgen.context import with_native_function +from torchgen.model import Argument, NativeFunction, SchemaKind, TensorOptionsArguments +from torchgen.utils import FileManager + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +# Note [Manual Backend kernels] +# For these ops, we want to manually register to dispatch key Backend and +# skip codegen-ed registration to all keys before Backend. +# For codegen this means: +# - op set below must match ops with manual_kernel_registration=True in native_functions.yaml +# where we skip codegen backend kernels +# - all ops below are part of MANUAL_AUTOGRAD to skip codegen Autograd kernel registration +# - all ops below are part of MANUAL_TRACER to skip codegen Tracer kernel registration +# Note: we still register to dispatch key Profiler for these ops, keeping it untouched for now. +# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp +MANUAL_BACKEND = { + "options", + "data", + "set_data", + "is_leaf", + "output_nr", + "_version", + "retain_grad", + "_backward", + "requires_grad_", +} + +# For these ops we want to skip the codegen-ed registration to both Autograd and Tracer keys. +# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp +MANUAL_AUTOGRAD_AND_TRACER = { + "resize_", + "resize_as_", + "detach", + "detach_", + "copy_", + "_fw_primal", + "_make_dual", +} + +# Currently MANUAL_AUTOGRAD and MANUAL_TRACER share the same set of ops: +# union(MANUAL_BACKEND, MANUAL_AUTOGRAD_AND_TRACER) +# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp +MANUAL_AUTOGRAD = MANUAL_TRACER = MANUAL_BACKEND | MANUAL_AUTOGRAD_AND_TRACER + +# These functions we don't want to record for tracing, because we always want +# to trace their constituent parts. This is a temporary hack in lieue +# of proper scopes, where subsequent compilation passes can ask for the unfolding +# on demand. Only concrete ATen methods can be disabled this way; it will have +# NO EFFECT otherwise. +DONT_RECORD_TRACE = { + "convolution", + "conv1d", + "conv2d", + "conv3d", + "conv_transpose1d", + "conv_transpose2d", + "conv_transpose3d", + "lstm_cell", + "gru_cell", + "rnn_tanh_cell", + "rnn_relu_cell", + # FIXME: figure out a better way when we support sparse tensors in jit + "_coalesced", +} + + +def should_trace(f: NativeFunction) -> bool: + # Operations involving Storage or Type are not traceable at the moment + if any( + str(arg.type) in {"Storage", "Type"} for arg in f.func.schema_order_arguments() + ): + return False + # We can't trace functions which don't have any Tensor or TensorList returns + if not any(r.type.is_tensor_like() for r in f.func.returns): + return False + return f.func.name.name.base not in DONT_RECORD_TRACE + + +SELECT = CodeTemplate( + """\ + +if (${cond}) { + ${true} +} else { + ${false} +} +""" +) + +OP_NAME = CodeTemplate( + """\ +op_name = c10::Symbol::fromQualString("aten::${trace_name}"); +""" +) + +# These functions have their names recorded under trace renamed, +RENAME_TRACE = { + "zero": "zeros_like", # replacing aten::zero_ with aten::zeros_like + "fill": "full_like", # replacing aten::fill_ with aten::full_like +} + + +def format_trace_op_name(f: NativeFunction) -> str: + # TODO: byte-for-byte compatible with old codegen behavior - should clean up + if ( + f.func.kind() in (SchemaKind.functional, SchemaKind.out) + or f.func.name.name.dunder_method + ): + # special case for *_out functions: the in-place and out-of-place ops + # are overloaded with the same name in the JIT + trace_name = str(f.func.name.name) + trace_name = RENAME_TRACE.get(trace_name, trace_name) + return OP_NAME.substitute(trace_name=trace_name) + + # otherwise, this is an in-place op and we need to emit both in- and + # out-of-place versions + outplace_trace_name = f.func.name.name.base + inplace_trace_name = cpp.name(f.func) + outplace_trace_name = RENAME_TRACE.get(outplace_trace_name, outplace_trace_name) + inplace_trace_name = RENAME_TRACE.get(inplace_trace_name, inplace_trace_name) + + return SELECT.substitute( + cond="tracer_state->force_outplace", + true=OP_NAME.substitute(trace_name=outplace_trace_name), + false=OP_NAME.substitute(trace_name=inplace_trace_name), + ) + + +ADD_TRACE_INPUT = CodeTemplate("""jit::tracer::addInputs(node, "${name}", ${input});""") + + +def format_trace_inputs(f: NativeFunction) -> str: + def dispatch_trace_input(arg: Argument | TensorOptionsArguments) -> Sequence[str]: + if isinstance(arg, TensorOptionsArguments): + name = "options" + return [ + ADD_TRACE_INPUT.substitute( + name=name, input="c10::optTypeMetaToScalarType(options.dtype_opt())" + ), + ADD_TRACE_INPUT.substitute(name=name, input="options.layout()"), + ADD_TRACE_INPUT.substitute(name=name, input="options.device()"), + ADD_TRACE_INPUT.substitute(name=name, input="options.pinned_memory()"), + ] + else: + name = arg.name + if str(arg.type) == "Tensor?[]": + return [f'jit::tracer::addInputs(node, "{name}", {name});'] + else: + return [ADD_TRACE_INPUT.substitute(name=name, input=name)] + + args: list[Argument | TensorOptionsArguments] = list( + f.func.schema_order_arguments() + ) + + if f.func.is_out_fn(): + # *_out functions take the result as a separate argument, but we don't want to + # trace that argument directly. Instead, we trace its TensorOptions. + # So first, we need to remove the out argument from the list of arguments to trace. + num_out_args = len(f.func.arguments.out) + args = args[:-num_out_args] + + trace_inputs = itertools.chain.from_iterable( + dispatch_trace_input(arg) for arg in args + ) + + if f.func.is_out_fn(): + # for *_out functions, handle the result argument differently for inplace/outplace. + # For inplace: just add the input to the end to confirm with the JIT schema + inplace = [ + ADD_TRACE_INPUT.substitute( + name=f.func.arguments.out[i].name, input=f.func.arguments.out[i].name + ) + # pyrefly: ignore [unbound-name] + for i in range(num_out_args) + ] + + # for outplace: do nothing, except if the function is a factory. + # Factories are a bit special because their out-of-place overloads + # take an extra TensorOptions argument, which is missing in the _out function + has_tensor_return = any(r.type.is_tensor_like() for r in f.func.returns) + has_tensor_input_arg = any( + a.type.is_tensor_like() for a in f.func.arguments.flat_non_out + ) + is_factory_method = f.category_override == "factory" or ( + has_tensor_return and not has_tensor_input_arg + ) + + # HACK: preserve old codegen behavior - the old codegen set the `is_factory_method` + # flag for the whole family of ops with the same basename if any of them is a + # factory method. For most cases the whole family of ops are indeed all factory + # method - 'normal' is the only exception. So we handle it specially here to avoid + # cloning the old logic. + if f.func.name.name.base == "normal": + is_factory_method = True + + if is_factory_method: + outplace = [ + ADD_TRACE_INPUT.substitute( + name="out", + input="c10::optTypeMetaToScalarType(out.options().dtype_opt())", + ), + ADD_TRACE_INPUT.substitute(name="out", input="out.options().layout()"), + ADD_TRACE_INPUT.substitute(name="out", input="out.options().device()"), + ADD_TRACE_INPUT.substitute( + name="out", input="out.options().pinned_memory()" + ), + ] + else: + outplace = [] + + trace_inputs = itertools.chain( + trace_inputs, + [ + SELECT.substitute( + cond="tracer_state->force_outplace", + true="\n".join(outplace), + false="\n".join(inplace), + ) + ], + ) + + return "\n".join(trace_inputs) + + +# `torch.jit.trace` have undocumented keyword argument `_force_outplace`, +# which force jit to replace functions with outplace variants (for +# example `aten::add_` becomes `aten::add`). +# +# This replacement implemented in-place with minimum modifications of +# arguments stack (as it assumes that outplace call has the same arguments +# as inplace version). +# +# However there are no such substitutions available for `aten::fill_` +# and `aten::zero_` operators, as we never implemented `aten::fill` +# and `aten::zero`. So jit tracing hack replacing `aten::zero_` with +# `aten::zeros_like` and replacing `aten::fill_` with `aten::full_like`. +# +# But as they potentially can have different arguments, we also have +# to hack into the stack and add missing ones. +# +# A possible alternative would be: +# +# - Add `aten::fill` and `aten::zero` +# +# - Or keep `aten::zeros_like` arguments aligned with `aten::zero_` +# arguments (inside of the `native_functions.yaml`) +RENAME_TRACE_ADD_ARGS = { + "fill": """\ + jit::tracer::addInputs(node, "options", ::std::optional()); + jit::tracer::addInputs(node, "options", layout_or_default(::std::nullopt)); + jit::tracer::addInputs(node, "options", device_or_default(::std::nullopt)); + jit::tracer::addInputs(node, "options", pinned_memory_or_default(::std::nullopt)); + ::std::optional memory_format = c10::MemoryFormat::Preserve; + jit::tracer::addInputs(node, "memory_format", memory_format); +""", + "zero": """\ + jit::tracer::addInputs(node, "options", ::std::optional()); + jit::tracer::addInputs(node, "options", layout_or_default(::std::nullopt)); + jit::tracer::addInputs(node, "options", device_or_default(::std::nullopt)); + jit::tracer::addInputs(node, "options", pinned_memory_or_default(::std::nullopt)); + ::std::optional memory_format = c10::MemoryFormat::Preserve; + jit::tracer::addInputs(node, "memory_format", memory_format); +""", +} + +INPLACE_GUARD = CodeTemplate( + """\ +jit::tracer::ensureUniqueIfOutOfPlaced("${name}", ${mutable_input}); +""" +) + +PRE_RECORD_TRACE = CodeTemplate( + """\ +torch::jit::Node* node = nullptr; +std::shared_ptr tracer_state; +if (jit::tracer::isTracing()) { + tracer_state = jit::tracer::getTracingState(); + at::Symbol op_name; + ${set_op_name} + node = tracer_state->createNode(op_name, /*num_outputs=*/0); + jit::tracer::recordSourceLocation(node); + ${add_trace_inputs} + tracer_state->insertNode(node); + ${inplace_guard} + jit::tracer::setTracingState(nullptr); +} +""" +) + + +def format_prerecord_trace(f: NativeFunction) -> str: + if not should_trace(f): + return "" + + # TODO: clean up old codegen behavior + is_inplace = ( + f.func.kind() in (SchemaKind.inplace, SchemaKind.out) + and not f.func.name.name.dunder_method + ) + add_args = ( + RENAME_TRACE_ADD_ARGS.get(f.func.name.name.base, "") if is_inplace else "" + ) + additional_inputs = ( + SELECT.substitute( + cond="tracer_state->force_outplace", + true=add_args, + false="", + ) + if add_args + else "" + ) + + return PRE_RECORD_TRACE.substitute( + set_op_name=format_trace_op_name(f), + add_trace_inputs=format_trace_inputs(f) + additional_inputs, + inplace_guard=INPLACE_GUARD.substitute( + name=cpp.name(f.func), + mutable_input=f.func.arguments.out[0].name + if f.func.arguments.out + else "self", + ) + if is_inplace + else "", + ) + + +POST_RECORD_TRACE = CodeTemplate( + """\ +if (tracer_state) { + jit::tracer::setTracingState(std::move(tracer_state)); + ${add_trace_outputs} +} +""" +) + + +def format_postrecord_trace(f: NativeFunction) -> str: + if not should_trace(f): + return "" + + # For outplacing ops, *_out overloads require special handling to move the + # output *argument* to a return value + if f.func.is_out_fn(): + output_names_outplace = [arg.name for arg in f.func.arguments.out] + output_names_inplace = cpp.return_names(f) + + # Code size optimization: the common case is that the return value is + # the same for both variants + if output_names_outplace == output_names_inplace: + outputs = [ + f"jit::tracer::addOutput(node, {n});" for n in output_names_outplace + ] + return POST_RECORD_TRACE.substitute(add_trace_outputs=outputs) + + selection = SELECT.substitute( + cond="force_outplace", + true="\n".join( + f"jit::tracer::addOutput(node, {n});" for n in output_names_outplace + ), + false="\n".join( + f"jit::tracer::addOutput(node, {n});" for n in output_names_inplace + ), + ) + return POST_RECORD_TRACE.substitute(add_trace_outputs=selection) + else: + output_names = cpp.return_names(f) + outputs = [f"jit::tracer::addOutput(node, {n});" for n in output_names] + return POST_RECORD_TRACE.substitute(add_trace_outputs=outputs) + + +def tie_return_values(f: NativeFunction) -> str: + if len(f.func.returns) == 1: + return f"auto {f.func.returns[0].name or 'result'}" + names = cpp.return_names(f) + return f"auto [{', '.join(names)}]" + + +def get_return_value(f: NativeFunction) -> str: + names = cpp.return_names(f) + if len(f.func.returns) == 1: + return names[0] + if f.func.kind() == SchemaKind.out: + return f"std::forward_as_tuple({', '.join(names)})" + else: + moved = ", ".join(f"std::move({name})" for name in names) + return f"std::make_tuple({moved})" + + +TRACE_DISPATCH = CodeTemplate( + """\ +${assign_return_values}at::_ops::${unambiguous_name}::redispatch(${unpacked_args});""" +) + + +def emit_trace_body(f: NativeFunction) -> list[str]: + trace_body: list[str] = [] + + trace_body.append(format_prerecord_trace(f)) + + dispatcher_sig = DispatcherSignature.from_schema(f.func) + dispatcher_exprs = dispatcher_sig.exprs() + + # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance. + # See Note [Plumbing Keys Through The Dispatcher] for details. + dispatch_key_set = "ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Tracer)" + redispatch_args = ", ".join([dispatch_key_set] + [a.expr for a in dispatcher_exprs]) + + assign_return_values = ( + f"{tie_return_values(f)} = " + if f.func.kind() in [SchemaKind.functional, SchemaKind.mutable] + and f.func.returns + else "" + ) + + # Note that this calls the slow, dispatching variants of manual_cpp_binding ops. + # We could probably work harder to ensure that the fast variants are + # called instead, but the perf benefit would be minimal. + trace_body.append( + TRACE_DISPATCH.substitute( + assign_return_values=assign_return_values, + unambiguous_name=f.func.name.unambiguous_name(), + unpacked_args=redispatch_args, + ) + ) + + trace_body.append(format_postrecord_trace(f)) + if f.func.returns: + trace_body.append(f"return {get_return_value(f)};") + return trace_body + + +METHOD_DEFINITION = CodeTemplate( + """\ +${return_type} ${type_wrapper_name}(${formals}) { + ${type_definition_body} +} +""" +) + + +def type_wrapper_name(f: NativeFunction, key: str = "Default") -> str: + if f.func.name.overload_name: + name = f"{cpp.name(f.func)}_{f.func.name.overload_name}" + else: + name = cpp.name(f.func) + + # The key argument is only used in gen_variable_type where we need fns per autograd dispatch key. + # In gen_trace_type and gen_inplace_view_type where only one fn per native_fn must be generated, + # the key argument should not be passed. + # We do not append key if it is Default so that generated functions from + # before per-dispatch-key derivatives were added retain the same names. + if key != "Default": + name = name + f"_{key}" + return name + + +@with_native_function +def method_definition(f: NativeFunction) -> str: + assert cpp.name(f.func) not in MANUAL_TRACER + + formals = ", ".join( + # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance. + # See Note [Plumbing Keys Through The Dispatcher] for details. + ["c10::DispatchKeySet ks"] + + [ + f"{cpp.argument_type(a, binds='__placeholder__', symint=True).cpp_type()} {a.name}" + for a in f.func.schema_order_arguments() + ] + ) + + return METHOD_DEFINITION.substitute( + return_type=cpp.returns_type(f.func.returns, symint=True).cpp_type(), + type_wrapper_name=type_wrapper_name(f), + formals=formals, + type_definition_body=emit_trace_body(f), + ) + + +WRAPPER_REGISTRATION = CodeTemplate( + """\ +m.impl("${name}", + TORCH_FN(${class_type}::${type_wrapper_name}) +); +""" +) + + +@with_native_function +def method_registration(f: NativeFunction) -> str: + assert cpp.name(f.func) not in MANUAL_TRACER + + return WRAPPER_REGISTRATION.substitute( + name=f.func.name, + type_wrapper_name=type_wrapper_name(f), + class_type="TraceType", + ) + + +def gen_trace_type_func(fn: NativeFunction) -> dict[str, list[str]]: + return { + "ops_headers": [f"#include "], + "trace_method_definitions": [method_definition(fn)], + "trace_wrapper_registrations": [method_registration(fn)], + } + + +def gen_trace_type( + out: str, native_functions: list[NativeFunction], template_path: str +) -> None: + # NOTE: see Note [Sharded File] at the top of the VariableType.cpp + # template regarding sharding of the generated files. + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + fm.write_sharded( + "TraceType.cpp", + [fn for fn in native_functions if cpp.name(fn.func) not in MANUAL_TRACER], + key_fn=lambda fn: fn.root_name, + base_env={ + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/TraceType.cpp", + }, + env_callable=gen_trace_type_func, + num_shards=5, + sharded_keys={ + "ops_headers", + "trace_method_definitions", + "trace_wrapper_registrations", + }, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_variable_factories.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_variable_factories.py new file mode 100644 index 0000000000000000000000000000000000000000..9916a77385d38f01e83416d4303cb17ac17de700 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_variable_factories.py @@ -0,0 +1,116 @@ +# Generates C++ functions that wrap ATen tensor factory methods to turn them into Variables. +# +# This writes one file: variable_factories.h + +from __future__ import annotations + +import re + +import torchgen.api.python as python +from torchgen.api import cpp +from torchgen.api.types import CppSignatureGroup +from torchgen.context import with_native_function +from torchgen.gen import parse_native_yaml +from torchgen.model import NativeFunction, TensorOptionsArguments, Variant +from torchgen.utils import FileManager, mapMaybe + + +OPTIONAL_TYPE_PATTERN = re.compile(r"std::optional<(.+)>") +TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)") + + +# Add 'at::' to types defined in ATen namespace, e.g. Tensor, TensorList, IntArrayRef and etc. +# TODO: maybe update the cpp argument API to take optional namespace argument? +def fully_qualified_type(argument_type: str) -> str: + def maybe_optional_type(type: str, is_opt: bool) -> str: + return f"std::optional<{type}>" if is_opt else type + + opt_match = OPTIONAL_TYPE_PATTERN.match(argument_type) + is_opt = opt_match is not None + if opt_match: + argument_type = argument_type[opt_match.start(1) : opt_match.end(1)] + match = TYPE_PATTERN.match(argument_type) + if match is None: + return maybe_optional_type(argument_type, is_opt) + index = match.start(1) + qualified_type = f"{argument_type[:index]}at::{argument_type[index:]}" + return maybe_optional_type(qualified_type, is_opt) + + +def gen_variable_factories( + out: str, native_yaml_path: str, tags_yaml_path: str, template_path: str +) -> None: + native_functions = parse_native_yaml( + native_yaml_path, tags_yaml_path + ).native_functions + factory_functions = [fn for fn in native_functions if is_factory_function(fn)] + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + fm.write_with_template( + "variable_factories.h", + "variable_factories.h", + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/variable_factories.h", + "ops_headers": [ + f"#include " for fn in factory_functions + ], + "function_definitions": list(mapMaybe(process_function, factory_functions)), + }, + ) + + +@with_native_function +def is_factory_function(f: NativeFunction) -> bool: + if Variant.function not in f.variants: + return False + + name = cpp.name(f.func) + has_tensor_options = python.has_tensor_options(f) + return has_tensor_options or name.endswith("_like") + + +@with_native_function +def process_function(f: NativeFunction) -> str | None: + name = cpp.name(f.func) + has_tensor_options = python.has_tensor_options(f) + is_factory = has_tensor_options or name.endswith("_like") + + if Variant.function not in f.variants or not is_factory: + return None + + cpp_sigs = CppSignatureGroup.from_native_function(f, method=False) + sigs = [cpp_sigs.signature] + if cpp_sigs.symint_signature is not None: + sigs.append(cpp_sigs.symint_signature) + r = "" + for sig in sigs: + formals: list[str] = [] + exprs: list[str] = [] + requires_grad = "false" + for arg in sig.arguments(): + qualified_type = fully_qualified_type(arg.type) + if arg.default: + formals.append(f"{qualified_type} {arg.name} = {arg.default}") + else: + formals.append(f"{qualified_type} {arg.name}") + + if isinstance(arg.argument, TensorOptionsArguments): + # note: we remove the requires_grad setting from the TensorOptions because + # it is ignored anyways (and we actually have an assertion that it isn't set + # which would fail otherwise). We handle requires_grad explicitly here + # instead of passing it through to the kernel. + exprs.append( + f"at::TensorOptions({arg.name}).requires_grad(::std::nullopt)" + ) + # Manually set the requires_grad bit on the result tensor. + requires_grad = f"{arg.name}.requires_grad()" + else: + exprs.append(arg.name) + + r += f"""\ +inline at::Tensor {sig.name()}({", ".join(formals)}) {{ + at::AutoDispatchBelowADInplaceOrView guard; + return autograd::make_variable(at::{sig.name()}({", ".join(exprs)}), /*requires_grad=*/{requires_grad}); +}} +""" + return r diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_variable_type.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_variable_type.py new file mode 100644 index 0000000000000000000000000000000000000000..4b6ce65bb0bffdbf5c92759ebe55f173a494828f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_variable_type.py @@ -0,0 +1,2203 @@ +# Generates VariableType.h/cpp +# +# **If any changes are being made to the VariableType codegen please also check +# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp +# +# VariableType is a subclass of at::Type that provides the binding code +# necessary to provide a differentiable version of ATen operators. There are a +# number of different things we could mean: +# +# - Given a non-differentiable forward implementation, we might +# directly associate it with a backward implementation to make +# it differentiable. This is the common case. +# +# - Some functions don't need a backwards implementation, because +# backpropagation will never propagate beyond them. There are a +# number of different reasons why this may be the case: +# +# - The function has no differentiable inputs +# - The function's output is not differentiable +# - The function has no data dependency on its input +# +# - Some function don't need a backwards implementation because they +# are implemented as a composition of other (differentiable) ATen +# functions. These are dispatched directly to the Type superclass, +# which will in turn dispatch back to VariableType for its +# differentiable subcomponents. +# + +from __future__ import annotations + +import re +from typing import TYPE_CHECKING + +from torchgen.api import cpp +from torchgen.api.autograd import ( + DifferentiableInput, + dispatch_strategy, + ForwardDerivative, + gen_differentiable_outputs, + is_differentiable, + NativeFunctionWithDifferentiabilityInfo, + SavedAttribute, +) +from torchgen.api.types import ( + ArrayRefCType, + BaseCppType, + BaseCType, + Binding, + intArrayRefT, + iTensorListRefT, + ListCType, + MutRefCType, + OptionalCType, + scalarT, + SpecialArgName, + stringT, + symIntArrayRefT, + TENSOR_LIST_LIKE_CTYPES, + tensorListT, + tensorT, + TupleCType, + VectorCType, +) +from torchgen.code_template import CodeTemplate +from torchgen.context import ( + native_function_manager, + with_native_function, + with_native_function_and, +) +from torchgen.model import ( + Argument, + BaseType, + ListType, + NativeFunction, + SchemaKind, + SelfArgument, + TensorOptionsArguments, +) +from torchgen.utils import FileManager, mapMaybe + +from .context import with_native_function_with_differentiability_info_and_key +from .gen_inplace_or_view_type import ( + ALL_VIEW_FUNCTIONS, + ASSIGN_RETURN_VALUE, + AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION, + gen_formals, + get_base_name, + get_view_info, + is_tensor_list_type, + is_tensor_type, + METHOD_DEFINITION, + modifies_arguments, + TMP_VAR, + unpack_args, + unpacked_name, + use_derived, + WRAPPER_REGISTRATION, +) +from .gen_trace_type import ( + get_return_value, + MANUAL_AUTOGRAD_AND_TRACER, + MANUAL_BACKEND, + tie_return_values, + type_wrapper_name, +) + + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + +# We don't set or modify grad_fn on these methods. Generally, they return +# tensors that have requires_grad=False. In-place functions listed here will +# not examine or modify requires_grad or grad_fn. +# NB: this does NOT include overload name +DONT_REQUIRE_DERIVATIVE = { + # These only depend on the input Tensor's shape and device, not the data + "empty_like", + "ones_like", + "full_like", + "zeros_like", + "rand_like", + "randn_like", + "new_empty", + "new_empty_strided", + "new_full", + "new_zeros", + "new_ones", + # These are only implemented on integral types + "__and__", + "__iand__", + "__ilshift__", + "__ior__", + "__irshift__", + "__ixor__", + "__lshift__", + "__or__", + "__rshift__", + "__xor__", + # These work on integral data types, and hence don't require derivative + "_sobol_engine_draw", + "_sobol_engine_ff", + "_sobol_engine_scramble_", + "_sobol_engine_initialize_state_", + # This is an unsafe method that is meant to be out of reach of autograd. + "_coalesced_", + # Quantize functions should not record gradients + "quantize_per_tensor", + "quantize_per_channel", + # Functions that return integers should not have output that require gradients + "argmax", + "argmin", + "argsort", + "searchsorted", + "bucketize", + # Functions that return booleans are not differentiable + "isnan", + "isposinf", + "isneginf", + "isinf", + "signbit", + "isin", + "allclose", + # Functions return none are not differentiable + "record_stream", + # These functions are not differentiable + "logical_and", + "logical_xor", + "logical_not", + "logical_or", + # This function returns nested_tensor shape as a tensor that is non-differentiable + "_nested_tensor_size", + "_nested_tensor_strides", + "_nested_tensor_storage_offsets", +} + +# The C -> R functions at the time of adding this are still being audited and tested +# but will not error out. +# C -> C, R -> C functions for which backward is correctly implemented and tested +GRADIENT_IMPLEMENTED_FOR_COMPLEX = { + "fill", + "t", + "t_copy", + "view", + "reshape", + "reshape_as", + "view_as", + "view_copy", + "roll", + "clone", + "block_diag", + "diag_embed", + "repeat", + "expand", + "expand_copy", + "flip", + "fliplr", + "flipud", + "rot90", + "nanmean", + "nansum", + "transpose", + "transpose_copy", + "permute", + "permute_copy", + "squeeze", + "squeeze_copy", + "unsqueeze", + "unsqueeze_copy", + "resize", + "resize_as", + "tril", + "triu", + "chunk", + "zero_", + "eq_", + "ne_", + "add", + "__radd__", + "sum", + "_conj", + "sin", + "cos", + "mul", + "sinc", + "sinh", + "cosh", + "__rmul__", + "sgn", + "asin", + "acos", + "sub", + "div", + "cat", + "view_as_complex", + "index_put", + "neg", + "complex", + "select", + "where", + "as_strided", + "as_strided_copy", + "as_strided_scatter", + "slice", + "constant_pad_nd", + "unbind", + "unbind_copy", + "split", + "split_with_sizes", + "unsafe_split", + "split_with_sizes_backward", + "dot", + "vdot", + "cholesky", + "triangular_solve", + "mm", + "_unsafe_view", + "mv", + "outer", + "bmm", + "diagonal", + "alias", + "atan", + "log", + "log10", + "log1p", + "log2", + "logaddexp", + "logsumexp", + "logcumsumexp", + "reciprocal", + "tan", + "pow", + "rsqrt", + "tanh", + "tanh_backward", + "asinh", + "acosh", + "atanh", + "take", + "fill_", + "exp", + "exp2", + "expm1", + "nonzero", + "mean", + "std_mean", + "var_mean", + "inverse", + "solve", + "linalg_cholesky", + "addcmul", + "addcdiv", + "matrix_exp", + "linalg_matrix_exp", + "_linalg_eigh", + "cholesky_solve", + "linalg_qr", + "_linalg_svd", + "_fft_c2c", + "_fft_r2c", + "linalg_solve", + "sqrt", + "stack", + "gather", + "index_select", + "index_add_", + "linalg_inv", + "linalg_inv_ex", + "baddbmm", + "addbmm", + "addmm", + "addmv", + "addr", + "linalg_householder_product", + "ormqr", + "reflection_pad1d", + "reflection_pad2d", + "reflection_pad3d", + "linalg_cholesky_ex", + "linalg_eig", + "diagonal_copy", + "diagonal_scatter", + "alias_copy", + "select_backward", + "diagonal_backward", + "slice_backward", + "reflection_pad1d_backward", + "reflection_pad2d_backward", + "reflection_pad3d_backward", + "_sparse_sparse_matmul", + "replication_pad1d", + "replication_pad2d", + "replication_pad3d", + "put", + "put_", + "_to_copy", + "replication_pad1d_backward", + "replication_pad2d_backward", + "replication_pad3d_backward", + "diag", + "masked_scatter", + "masked_select", + "index_add", + "index_fill", + "trace", + "polar", + "cumsum", + "rsub", + "eig", + "lerp", + "linalg_vector_norm", + "cumprod", + "prod", + "index_copy", + "lu", + "unfold", + "unfold_backward", + "index", + "masked_fill", + "masked_scatter_backward", + "linalg_cross", + "lu_unpack", + "renorm", + "_conj_physical", + "linalg_lu_factor_ex", + "scatter", + "scatter_add", + "sigmoid", + "sigmoid_backward", + "sparse_mask", + "trapezoid", + "cumulative_trapezoid", + "conj_physical_", + "_neg_view", + "_reshape_alias", + "_reshape_copy", + "_linalg_det", + "lu_solve", + "linalg_solve_triangular", + "linalg_pinv", + "linalg_lstsq", + "unfold_copy", + "col2im", + "im2col", + "cholesky_inverse", + "to_sparse", + "sparse_sampled_addmm", + "linalg_lu", + "pixel_shuffle", + "pixel_unshuffle", + "channel_shuffle", + "linalg_lu_solve", + "_linalg_slogdet", + "_linalg_solve_ex", + "_unsafe_index", + "_unsafe_index_put", + "_unsafe_masked_index", + "_unsafe_masked_index_put_accumulate", +} + +GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = { + "_to_dense", + "_coalesce", + "coalesce", + "values", + "_sparse_coo_tensor_with_dims_and_tensors", + "_sparse_addmm", +} + +GRADIENT_IMPLEMENTED_FOR_COMPLEX.update(GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX) + +# Some operators invalidate the grad_accumulator. Let's reset it. +RESET_GRAD_ACCUMULATOR = {"set_", "resize_"} + +# NOTE [ TensorImpl and Storage Pointer Sanity Checks ] +# +# We check the following properties: +# 1) A function should never change the input tensors' underlying c10::TensorImpl +# pointers or c10::Storage pointers, even if it modifies its input tensors (via +# inplace or out-variants) +# If the function does not modify its arguments, we also check the following properties +# pertaining to its output: +# 2) Its TensorImpl has use_count of 1 (or 2 if it has a PyObject) +# 3) If the function is a view function, it has the same StorageImpl as that of +# the input it is aliased with. Otherwise, its StorageImpl has use_count of 1 +# +# The following code templates implement the checks for this invariant: +SAVE_TENSOR_STORAGE = CodeTemplate( + """\ +auto ${tensor_name}_storage_saved = + ${tensor_name}.has_storage() ? ::std::optional(${tensor_name}.storage()) : ::std::nullopt; +""" +) + + +# If tensor_name == out_tensor_name, used to enforce (1), otherwise used for (2) +ENFORCE_SAME_TENSOR_STORAGE = CodeTemplate( + """\ +if (${tensor_name}_storage_saved.has_value() && + !at::impl::dispatch_mode_enabled() && + !at::impl::tensor_has_dispatch(${tensor_name}) && + !at::impl::tensor_has_dispatch(${out_tensor_name})) + TORCH_INTERNAL_ASSERT(${tensor_name}_storage_saved.value().is_alias_of(${out_tensor_name}.storage())); +""" +) + +SAVE_TENSORLIST_STORAGE = CodeTemplate( + """\ +std::vector<::std::optional> ${tensorlist_name}_storage_saved(${tensorlist_name}.size()); +for (const Tensor& tensor : ${tensorlist_name}) + ${tensorlist_name}_storage_saved.push_back( + tensor.has_storage() ? ::std::optional(tensor.storage()) : ::std::nullopt); +""" +) + +ENFORCE_SAME_TENSORLIST_STORAGE = CodeTemplate( + """\ +for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) { + if (${tensorlist_name}_storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(${tensorlist_name})) + TORCH_INTERNAL_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of(${tensorlist_name}[i].storage())); +} +""" +) + +SAVE_OPTIONALTENSORLIST_STORAGE = CodeTemplate( + """\ +std::vector<::std::optional> ${tensorlist_name}_storage_saved(${tensorlist_name}.size()); +for (const ::std::optional& tensor : ${tensorlist_name}) + ${tensorlist_name}_storage_saved.push_back( + tensor.has_value() && tensor->has_storage() ? ::std::optional(tensor->storage()) : ::std::nullopt); +""" +) + +ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE = CodeTemplate( + """\ +for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) { + if (${tensorlist_name}_storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(${tensorlist_name})) + TORCH_INTERNAL_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of( + static_cast<::std::optional>(${tensorlist_name}[i])->storage())); +} +""" +) + +SAVE_TENSOR_IMPL = CodeTemplate( + """\ +c10::intrusive_ptr ${tensor_name}_impl_saved; +if (${tensor_name}.defined()) ${tensor_name}_impl_saved = ${tensor_name}.getIntrusivePtr(); +""" +) + +ENFORCE_SAME_TENSOR_IMPL = CodeTemplate( + """\ +if (${tensor_name}_impl_saved && !at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(${tensor_name})) + TORCH_INTERNAL_ASSERT(${tensor_name}_impl_saved == ${tensor_name}.getIntrusivePtr()); +""" +) + +ENFORCE_TENSOR_IMPL_USE_COUNT = CodeTemplate( + """\ +if (!at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(${tensor_name})) + TORCH_INTERNAL_ASSERT(${tensor_name}.use_count() == expected_fresh_use_count(${tensor_name}), "function: ${fn_name}"); +""" +) + +ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE = CodeTemplate( + """\ +if (${tensor_name}.has_storage() && !at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(${tensor_name})) { + TORCH_INTERNAL_ASSERT(${tensor_name}.storage().use_count() == 1, "function: ${fn_name}"); +} +""" +) + +SAVE_TENSORLIST_IMPL = CodeTemplate( + """\ +std::vector> ${tensorlist_name}_impl_saved(${tensorlist_name}.size()); +for (size_t i=0; i<${tensorlist_name}.size(); i++) + if (${tensorlist_name}[i].defined()) ${tensorlist_name}_impl_saved[i] = ${tensorlist_name}[i].getIntrusivePtr(); +""" +) + +ENFORCE_SAME_TENSORLIST_IMPL = CodeTemplate( + """\ +for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) { + if (${tensorlist_name}_impl_saved[i] && !at::impl::tensorlist_has_dispatch(${tensorlist_name})) + TORCH_INTERNAL_ASSERT(${tensorlist_name}_impl_saved[i] == ${tensorlist_name}[i].getIntrusivePtr()); +} +""" +) + +SAVE_OPTIONALTENSORLIST_IMPL = CodeTemplate( + """\ +std::vector> ${tensorlist_name}_impl_saved(${tensorlist_name}.size()); +for (size_t i=0; i<${tensorlist_name}.size(); i++) { + ::std::optional t = ${tensorlist_name}[i]; + if (t.has_value() && t->defined()) ${tensorlist_name}_impl_saved[i] = t->getIntrusivePtr(); +} +""" +) + +ENFORCE_SAME_OPTIONALTENSORLIST_IMPL = CodeTemplate( + """\ +for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) { + if (${tensorlist_name}_impl_saved[i]) + TORCH_INTERNAL_ASSERT( + ${tensorlist_name}_impl_saved[i] == static_cast<::std::optional>(${tensorlist_name}[i])->getIntrusivePtr()); +} +""" +) + +# The following list contains functions that we don't enforce the invariant on. +DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE = { + # These functions are expected to change impl or storage of input tensors + "set_", + "_cudnn_rnn_flatten_weight", + "_unsafe_masked_index", + "_unsafe_masked_index_put_accumulate", +} +DONT_ENFORCE_TENSOR_IMPL_USE_COUNT = { + # These non-inplace, non-out functions return tensors with use_count > 1 + # Therefore, they MAY (but not necessarily) return one of its inputs as-is + # See https://github.com/pytorch/pytorch/issues/60426 for more information + "_embedding_bag", + "_embedding_bag_forward_only", + "q_per_channel_scales", + "q_per_channel_zero_points", + "lu_unpack", + "_cudnn_rnn_backward", + # The below failed StorageImpl use_count check but we skip tensor_impl check + # just in case + "_cudnn_rnn", + "dequantize_self", + # lift() should never actually be called with a requires_grad=True tensor, + "lift", + "lift_fresh", + "lift_fresh_copy", + # Nested Tensors related functions + # _nested_tensor_size() should never actually be called with requires_grad=True tensor + "_nested_tensor_size", + "_nested_tensor_strides", + "_nested_tensor_storage_offsets", +} + +DONT_ENFORCE_STORAGE_IMPL_USE_COUNT = { + # These non-view functions return tensors with storage use_count != 1 + "_slow_conv2d_forward", + "slow_conv3d_forward", + "channel_shuffle", + # If an input is returned as-is in output, we cannot guarantee its storage_impl + # use count to be 1 either. + *DONT_ENFORCE_TENSOR_IMPL_USE_COUNT, +} +# END CHECKS FOR [ TensorImpl and Storage Pointer Sanity Checks ] + +DECLARE_GRAD_FN = CodeTemplate( + """\ +std::shared_ptr<${op}> grad_fn; +""" +) + +DECLARE_VECTOR_OF_GRAD_FN = CodeTemplate( + """\ +std::vector> grad_fns; +""" +) + +SETUP_ANY_REQUIRES_GRAD = CodeTemplate( + """\ +[[maybe_unused]] auto _any_requires_grad = compute_requires_grad( ${args_with_derivatives} ); +${extra_differentiability_conditions} +""" +) + +SETUP_DERIVATIVE = CodeTemplate( + """\ +if (_any_requires_grad) { + ${setup} +} +""" +) + +SETUP_NONE_REQUIRES_GRAD = CodeTemplate( + """\ +if (compute_requires_grad( ${args_to_check} )) { + throw_error_out_requires_grad("${base_name}"); +} +""" +) + +ASSIGN_GRAD_FN = CodeTemplate( + """\ +grad_fn = std::shared_ptr<${op}>(new ${op}(${op_ctor}), deleteNode); +grad_fn->set_next_edges(collect_next_edges( ${args_with_derivatives} )); +""" +) + +# note(crcrpar): `compute_requires_grad` in the template below is supplied with arguments indexed with `i` +# while the `SETUP_ANY_REQUIRES_GRAD` above takes whole tensors and scalars. +ASSIGN_VECTOR_OF_GRAD_FN = CodeTemplate( + """\ +for (const auto& i : c10::irange( ${irange} )) { + const auto ith_requires_grad = compute_requires_grad(${args_with_derivatives}); + check_inplace(self[i], ith_requires_grad); + grad_fns.push_back([&]() -> std::shared_ptr<${op}> { + if (!ith_requires_grad) { + return nullptr; + } else { + auto grad_fn = std::shared_ptr<${op}>(new ${op}(${op_ctor}), deleteNode); + grad_fn->set_next_edges(collect_next_edges( ${args_with_derivatives} )); + return grad_fn; + } + }()); +} +""" +) + +CALL_REDISPATCH = CodeTemplate( + """\ +at::redispatch::${api_name}(${unpacked_args})""" +) +# If the non-variable operation has return values, we use the `tmp` variable to hold the +# values temporarily and pass the values to the return variables outside of the +# `at::AutoDispatchBelowAutograd` guard block. +DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES_JVP_DECOMP = CodeTemplate( + """\ +auto ${tmp_var} = ([&]() { + if (${any_has_forward_grad}) { + static c10::OperatorName full_name("aten::${op_name}", "${op_overload}"); + static ::std::optional opt_op = c10::Dispatcher::singleton().findSchema(full_name); + return impl::run_jit_decomposition_with_args_for_jvp<${return_types}>("${op_name}", *opt_op, ks, ${arg_names}); + } else { + ${guard} + return ${base_type_call}; + } +})(); +""" +) + +DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES = CodeTemplate( + """\ +auto ${tmp_var} = ([&]() { + ${guard} + return ${base_type_call}; +})(); +""" +) + +DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES = CodeTemplate( + """\ +{ + ${guard} + ${base_type_call}; +} +""" +) + +SET_HISTORY = CodeTemplate( + """\ +if (grad_fn) { + ${fn}_history(${differentiable_outputs}, grad_fn); +} +""" +) + +LOOP_OVER_VECTOR_OF_GRAD_FNS = CodeTemplate( + """\ +if (!grad_fns.empty()) { + ${preamble} + for (const auto& i : c10::irange(grad_fns.size())) { + auto grad_fn = grad_fns[i]; + if (grad_fn != nullptr) { + ${statements} + } + } +} +""" +) + +CONDITIONAL = CodeTemplate( + """\ +if (${cond}) { + ${statements} +} +""" +) + +RUN_ONLY_IN_DEBUG_MODE = CodeTemplate( + """\ +#ifndef NDEBUG +${statements} +#endif +""" +) + +FW_DERIVATIVE_CHECK_TEMPLATE = CodeTemplate( + """\ +isFwGradDefined(${req_inp})\ +""" +) +FW_DERIVATIVE_SIZE_CHECK_TEMPLATE = CodeTemplate( + """\ +TORCH_CHECK( + self.size() == ${inp_name}.size(), + "Tensor lists must have the same number of tensors, got ", + self.size(), + " and ", + ${inp_name}.size()); +""" +) + +FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE = CodeTemplate( + """\ +isFwGradDefinedTensorList(${req_inp})\ +""" +) + +FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE = CodeTemplate( + """\ +auto ${inp_name}_t_raw = toNonOptFwGrad(${inp}); +auto ${inp_name}_tensor = toNonOptTensor(${inp}); +auto ${inp_name}_t = (${inp_name}_t_raw.defined() || !${inp_name}_tensor.defined()) + ? ${inp_name}_t_raw : at::${zeros_fn}(${inp_name}_tensor.sym_sizes(), ${inp_name}_tensor.options()); +""" +) + +FW_DERIVATIVE_UPDATE_WRAPPED_NUM_TEMPLATE = CodeTemplate( + """\ +update_wrapped_number(${inp_name}_tensor, ${inp_name}_t); +""" +) + +FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE = CodeTemplate( + """\ +auto ${inp_name}_p = toNonOptPrimal(${inp}); +""" +) + +FW_DERIVATIVE_SETTER_TENSOR = CodeTemplate( + """\ +if (${out_arg}_new_fw_grad_opt.has_value() && ${out_arg}_new_fw_grad_opt.value().defined() && ${out_arg}.defined()) { + // The hardcoded 0 here will need to be updated once we support multiple levels. + ${out_arg}._set_fw_grad(${out_arg}_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ ${is_inplace}); +} +""" +) + +FW_DERIVATIVE_SETTER_TENSOR_FOREACH = CodeTemplate( + """\ +for (const auto& i : c10::irange(${out_arg}_new_fw_grad_opts.size())) { + auto& ${out_arg}_new_fw_grad_opt = ${out_arg}_new_fw_grad_opts[i]; + if (${out_arg}_new_fw_grad_opt.has_value() && ${out_arg}_new_fw_grad_opt.value().defined() && ${out_arg}[i].defined()) { + // The hardcoded 0 here will need to be updated once we support multiple levels. + ${out_arg}[i]._set_fw_grad(${out_arg}_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ ${is_inplace}); + } +} +""" +) + +FW_DERIVATIVE_SETTER_MULTI_OUTPUT = CodeTemplate( + """\ +if (${all_res}_new_fw_grad_opt.has_value() && std::get<${idx}>(${all_res}_new_fw_grad_opt.value()).defined() + && ${out_arg}.defined()) { + ${out_arg}._set_fw_grad(std::get<${idx}>(${all_res}_new_fw_grad_opt.value()), /* level */ 0, /* is_inplace_op */ false); +} +""" +) + +FW_DERIVATIVE_SETTER_TENSOR_LIST = CodeTemplate( + """\ +if (${out_arg}_new_fw_grad_opt.has_value()) { + auto ${out_arg}_new_fw_grad = ${out_arg}_new_fw_grad_opt.value(); + TORCH_INTERNAL_ASSERT(${out_arg}.size() == ${out_arg}_new_fw_grad.size()); + for (const auto i : c10::irange(${out_arg}.size())) { + if (${out_arg}_new_fw_grad[i].defined() && ${out_arg}[i].defined()) { + // The hardcoded 0 here will need to be updated once we support multiple levels. + ${out_arg}[i]._set_fw_grad(${out_arg}_new_fw_grad[i], /* level */ 0, /* is_inplace_op */ ${is_inplace}); + } + } +} +""" +) + +FW_DERIVATIVE_TEMPLATE = CodeTemplate( + """\ +${fw_grad_opt_definition} +if (${requires_fw_grad}) { + ${unpacked_arguments} + ${out_arg}_new_fw_grad_opt = ${formula}; +} +""" +) + +FW_DERIVATIVE_FOREACH_TEMPLATE = CodeTemplate( + """\ +${fw_grad_opt_definition} +for (const auto& i : c10::irange(${vector_of_optional_tensor}.size())) { + if (${any_has_forward_grad_for_current_index}) { + ${unpacked_arguments} + ${vector_of_optional_tensor}[i] = ${formula}; + } +} +""" +) + +FW_DERIVATIVE_FORBID_TEMPLATE = CodeTemplate( + """\ +TORCH_CHECK_NOT_IMPLEMENTED(!(${cond}), "Trying to use forward AD with ${name} that does not support it ${msg}"); +""" +) + +FW_DERIVATIVE_FORBID_LIST_TEMPLATE = CodeTemplate( + """\ +for (const auto& _t: ${arg}) { + TORCH_CHECK_NOT_IMPLEMENTED(!(${cond}), "Trying to use forward AD with ${name} that does not support it ${msg}"); +} +""" +) + + +def gen_variable_type( + out: str, + native_yaml_path: str, + tags_yaml_path: str, + fns_with_diff_infos: list[NativeFunctionWithDifferentiabilityInfo], + template_path: str, + used_keys: set[str], +) -> None: + """VariableType.h and VariableType.cpp body + + This is the at::Type subclass for differentiable tensors. The + implementation of each function dispatches to the base tensor type to + compute the output. The grad_fn is attached to differentiable functions. + """ + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + fm.write( + "VariableType.h", + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/VariableType.h" + }, + ) + + # helper that generates a TORCH_LIBRARY_IMPL macro for each + # dispatch key that appears in derivatives.yaml + def wrapper_registrations(used_keys: set[str]) -> str: + library_impl_macro_list: list[str] = [] + for key in sorted(used_keys): + dispatch_key = key + if key == "Default": + dispatch_key = "Autograd" + library_impl_macro = ( + f"TORCH_LIBRARY_IMPL(aten, {dispatch_key}, m) " + + "{\n" + + "${" + + f"wrapper_registrations_{key}" + + "}\n}" + ) + library_impl_macro_list += [library_impl_macro] + return "\n\n".join(library_impl_macro_list) + + # Generate a new template from VariableType.cpp which replaces ${wrapper_registrations} + # with per key TORCH_LIBRARY_IMPL macros for each key that appears in derivatives.yaml + fm1 = FileManager( + install_dir=out + "/templates", template_dir=template_path, dry_run=False + ) + fm1.write( + "VariableType.cpp", + lambda: { + "type_derived_method_definitions": "\n\n".join( + [ + "${" + f"type_derived_method_definitions_{key}" + "}" + for key in sorted(used_keys) + ] + ), + "wrapper_registrations": wrapper_registrations(used_keys), + }, + ) + + # Generate final VariableType_*.cpp files from the generated template + fm2 = FileManager(install_dir=out, template_dir=out + "/templates", dry_run=False) + + sharded_keys = set( + [f"type_derived_method_definitions_{key}" for key in sorted(used_keys)] + + [f"wrapper_registrations_{key}" for key in sorted(used_keys)] + ) + # NOTE: see Note [Sharded File] at the top of the VariableType.cpp + # template regarding sharding of the generated files. + fm2.write_sharded( + "VariableType.cpp", + [fn for fn in fns_with_diff_infos if use_derived(fn)], + key_fn=lambda fn: cpp.name(fn.func.func), + base_env={ + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/VariableType.cpp", + }, + env_callable=gen_variable_type_func, + num_shards=5, + sharded_keys=sharded_keys, + ) + + +@with_native_function_and +def gen_wrapper_registration(f: NativeFunction, key: str = "Default") -> str: + return WRAPPER_REGISTRATION.substitute( + unqual_operator_name_with_overload=f.func.name, + type_wrapper_name=type_wrapper_name(f, key), + class_type="VariableType", + ) + + +def gen_variable_type_func( + fn: NativeFunctionWithDifferentiabilityInfo, +) -> dict[str, list[str]]: + f = fn.func + result = {} + with native_function_manager(f): + name = cpp.name(f.func) + formals = gen_formals(f) + + if ( + fn.info is None + and str(f.func.name.name) not in RESET_GRAD_ACCUMULATOR + and get_base_name(f) not in DONT_REQUIRE_DERIVATIVE + and len(gen_differentiable_outputs(fn)) > 0 + and cpp.name(f.func) not in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE + and type_wrapper_name(f) not in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT + and type_wrapper_name(f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT + ): + # NOTE: [ Registering AutogradNotImplemented boxed kernel ] + # + # When there is no derivatives.yaml entry, we register a generic boxed + # NotImplemented kernel to set grad_fn to be NotImplemented, so that forward + # proceeds as usual but an error is properly produced on backward. + # TODO: it would be nice to not have these special cases + # + # There are several cases where still let codegen handle it: + # 1) ops that need to reset grad accumulator (we let codegen handle this case + # because) the list is (currently) only accessible in Python. + # 2) User explicitly specifies DONT_REQUIRE_DERIVATIVE. This basically makes + # autograd a fallthrough with NDEBUG checks. This can be useful for when all + # outputs are integral. + # 3) When there are no differentiable outputs. This is similar to (2). + # 4) There are certain ops where we skip certain NDEBUG checks. this is similar + # to (1). + type_definition = "" + wrapper_registration = AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION.substitute( + unqual_operator_name_with_overload=f.func.name + ) + result["type_derived_method_definitions_Default"] = [type_definition] + result["wrapper_registrations_Default"] = [wrapper_registration] + else: + if not fn.info: + key = "Default" + type_definition = METHOD_DEFINITION.substitute( + return_type=cpp.returns_type( + f.func.returns, symint=True + ).cpp_type(), + type_wrapper_name=type_wrapper_name(f, key), + type_definition_body=emit_body(fn, key), + formals=formals, + ) + wrapper_registration = gen_wrapper_registration(f, key) + result[f"type_derived_method_definitions_{key}"] = [type_definition] + result[f"wrapper_registrations_{key}"] = [wrapper_registration] + else: + for key in fn.info: + type_definition = METHOD_DEFINITION.substitute( + return_type=cpp.returns_type( + f.func.returns, symint=True + ).cpp_type(), + type_wrapper_name=type_wrapper_name(f, key), + type_definition_body=emit_body(fn, key), + formals=formals, + ) + wrapper_registration = gen_wrapper_registration(f, key) + result[f"type_derived_method_definitions_{key}"] = [type_definition] + result[f"wrapper_registrations_{key}"] = [wrapper_registration] + # See Note [Manual Backend kernels] + assert (name in MANUAL_BACKEND) == f.manual_kernel_registration + # If you want to register a kernel to Autograd, you must make the op abstract. + # In other words, this op must have dispatch section in native_functions.yaml. + if name in MANUAL_AUTOGRAD_AND_TRACER or ( + fn.info and any(info.has_derivatives for info in fn.info.values()) + ): + msg = ( + f"There's a formula for {name}(or its functional variant) in derivatives.yaml. " + f"It's required to add a dispatch section for it with explicit supported backends e.g CPU/CUDA " + f"or CompositeExplicitAutograd in native_functions.yaml. Please see " + f"https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#choosing-the-right-dispatch-keyword " + f"for instructions to choose the right dispatch keyword." + ) + assert f.is_abstract, msg + + return result + + +_foreach_ops_without_differentiability_info = { + # No reference backward available due to the lack of `{maximum, minimum}(tensor, scalar)`. + ("_foreach_maximum", "Scalar"), + ("_foreach_maximum", "ScalarList"), + ("_foreach_minimum", "Scalar"), + ("_foreach_minimum", "ScalarList"), + # No reference backward available as addcdiv/addcmul don't support Tensor as scaling factor. + ("_foreach_addcdiv", "Tensor"), + ("_foreach_addcmul", "Tensor"), + ("_foreach_copy", ""), +} + +_foreach_ops_with_different_arity = { + # These ops lack `alpha` of scaling factor to applied to the right hand side argument. + ("_foreach_add", "Scalar"), + ("_foreach_add", "ScalarList"), + ("_foreach_sub", "Scalar"), + ("_foreach_sub", "ScalarList"), +} + + +@with_native_function_with_differentiability_info_and_key +def emit_body( + fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default" +) -> list[str]: + assert dispatch_strategy(fn) == "use_derived" + f = fn.func + info = fn.info[key] if fn.info else None + fw_derivatives = fn.fw_derivatives.get(key, []) if fn.fw_derivatives else [] + + name = cpp.name(f.func) + inplace = f.func.kind() == SchemaKind.inplace + is_out_fn = f.func.kind() == SchemaKind.out + returns_void = len(f.func.returns) == 0 + base_name = get_base_name(f) + view_info = get_view_info(f) + + is_foreach = name.startswith("_foreach") + is_inplace_foreach = is_foreach and inplace + if is_inplace_foreach: + inplace_foreacharg2refarg: dict[Argument, Argument] = {} + refargname2inplace_foreacharg: dict[str, Argument] = {} + base_name_and_overload_name = (f.func.name.name.base, f.func.name.overload_name) + if info is None: + assert ( + base_name_and_overload_name + in _foreach_ops_without_differentiability_info + ), ( + f"{'.'.join(base_name_and_overload_name)} should have a differentiability info" + ) + else: + assert ( + len(f.func.arguments.flat_non_out) + == len(info.func.func.arguments.flat_non_out) + ) or (base_name_and_overload_name in _foreach_ops_with_different_arity), ( + f"{'.'.join(base_name_and_overload_name)} has {len(f.func.arguments.flat_non_out)} args " + f"but the reference has {len(info.func.func.arguments.flat_non_out)}" + ) + for foreach_arg, ref_arg in zip( + f.func.arguments.flat_non_out, info.func.func.arguments.flat_non_out + ): + foreach_arg_type = foreach_arg.type + if isinstance(foreach_arg_type, ListType): + foreach_arg_type = foreach_arg_type.elem + assert foreach_arg_type == ref_arg.type + inplace_foreacharg2refarg[foreach_arg] = ref_arg + refargname2inplace_foreacharg[ref_arg.name] = foreach_arg + + def gen_differentiable_input( + arg: Argument | SelfArgument | TensorOptionsArguments, + ) -> DifferentiableInput | None: + if isinstance(arg, TensorOptionsArguments): + return None + a: Argument = arg.argument if isinstance(arg, SelfArgument) else arg + + # TODO: `cpp_type` is only to keep it byte-for-byte compatible with the old codegen, should remove. + # NB: This is not a clone of cpp.argument() - TensorOptionsArguments / faithful / binds are + # not handled properly as they are irrelevant for this codegen. + cpp_type = cpp.argument_type(a, binds=a.name, symint=True).cpp_type() + + if not is_differentiable(a.name, a.type, info): + return None + return DifferentiableInput( + name=a.name, + type=a.type, + cpp_type=cpp_type, + ) + + @with_native_function + def gen_differentiable_inputs(f: NativeFunction) -> list[DifferentiableInput]: + arguments = list(f.func.arguments.non_out) + if is_inplace_foreach and info is not None: + for i, arg in enumerate(f.func.arguments.flat_non_out): + if arg in inplace_foreacharg2refarg: + # note(crcrpar): From what I understand, what matters is only the name. + # Thus originally I only replace argument only when the names are different. + # TODO(crcrpar): Make it simpler. + mapped_arg = inplace_foreacharg2refarg[arg] + arguments[i] = Argument( + mapped_arg.name, + mapped_arg.type, + mapped_arg.default, + mapped_arg.annotation, + ) + return list(mapMaybe(gen_differentiable_input, arguments)) + + def find_args_with_derivatives( + differentiable_inputs: list[DifferentiableInput], + ) -> list[DifferentiableInput]: + """Find arguments that have derivative definitions""" + if info is None or not info.has_derivatives: + return differentiable_inputs + names = {name for d in info.derivatives for name in d.var_names} + differentiable = [arg for arg in differentiable_inputs if arg.name in names] + if len(differentiable) != len(names): + missing = names - {arg.name for arg in differentiable} + raise RuntimeError( + f"Missing arguments for derivatives: {missing} in {info.name}" + ) + return differentiable + + differentiable_inputs = gen_differentiable_inputs(f) + args_with_derivatives = find_args_with_derivatives(differentiable_inputs) + differentiable_outputs = gen_differentiable_outputs(fn, key) + + undifferentiable = (base_name in DONT_REQUIRE_DERIVATIVE) or ( + name in DONT_REQUIRE_DERIVATIVE + ) + + requires_derivative = ( + (not undifferentiable) + and (len(differentiable_inputs) > 0) + and ( + (len(differentiable_outputs) > 0) + # note(crcrpar): In-place foreach functions are a void function. + or is_inplace_foreach + ) + ) + + if ( + info is not None + and info.has_derivatives + and not requires_derivative + # out= ops are allowed to have zero returns which cause requires_derivative to be False + # we shouldn't error out though (out= ops for autograd just redispatch) + and len(f.func.returns) > 0 + ): + raise RuntimeError( + f"ERROR: derivative ignored for {name} -- specified an autograd function without derivative" + ) + + # note(crcrpar): In-place foreach functions do not support forward AD + if requires_derivative and len(fw_derivatives) > 0 and not is_inplace_foreach: + assert sum(len(derivative.var_names) for derivative in fw_derivatives) == len( + differentiable_outputs + ), ( + "Expected the number of forward derivatives implemented to match the " + "number of differentiable outputs. NB: This only applies when at least " + "one forward derivative is implemented. Not implementing any forward " + "derivatives is also okay, and we would require inputs to the op to " + "not have associated tangents in that case." + ) + + try_jit_decomposition = ( + requires_derivative + and len(fw_derivatives) == 0 + and (not modifies_arguments(f)) + and (not returns_void) + ) + + def emit_save_inputs() -> list[str]: + setup: list[str] = [] + if info is None or not info.has_derivatives: + return setup + + has_tensorlist_arg = any( + is_tensor_list_type(arg.type) for arg in args_with_derivatives + ) + + # We don't want to save tensors if we know that they will never be used + # when computing the derivative, so we add guards to those statements + def guard_for(arg: SavedAttribute) -> str | None: + assert info is not None + + # It's hard to determine the edge offset if we have TensorLists + # NOTE(crcrpar): in-place foreach functions' arguments include tensorlist + # but their derivatives don't use it, so let them bypass this check. + if has_tensorlist_arg and (not is_inplace_foreach): + return None + + # Empirical evaluation of the cases where we insert those guards in + # backward show that they are somewhat useless. E.g. there's no need + # to guard on some values captured from forward, because they had to + # require_grad if the backward function even gets executed. I don't + # have any good ideas for detecting those cases, so I simply disabled the + # checks. + if "backward" in info.name: + return None + + # If there's a single derivative we could compute, we already have + # a requires_grad check that is sufficient + if len(args_with_derivatives) <= 1: + return None + + # We really only care about trimming down the amount of tensors we save + if arg.nctype.type != BaseCType(tensorT): + return None + + # We want to emit simple guards, so we only allow that if checking one + # input is enough to determine whether we need that value + used_in = [d for d in info.derivatives if arg in d.saved_inputs] + assert len(used_in) > 0 + if len(used_in) != 1: + return None + derivative = used_in[0] + + # Case with multioutput formulas + # TODO: process all derivative formulas!!! + if len(derivative.var_names) != 1: + wrap_opt_if_start = derivative.formula.find( + f"wrap_opt_if({arg.nctype.name}" + ) + if wrap_opt_if_start == -1: + return None + + wrap_opt_if_match = re.match( + rf"wrap_opt_if\({arg.nctype.name},(.*?)\)", + derivative.formula[wrap_opt_if_start:], + ) + assert wrap_opt_if_match is not None + + # Condition is between 'wrap_opt_if(var_name,' and ')'. + condition_slice = slice(len(rf"wrap_opt_if\({arg.nctype.name},"), -1) + wrap_opt_if_condition = wrap_opt_if_match.group(0)[ + condition_slice + ].strip() + # replace 'grad_input_mask[num]' with 'grad_fn->should_compute_output(num)' + wrap_opt_if_condition = re.sub( + r"grad_input_mask\[(\d+)\]", + r"grad_fn->should_compute_output(\1)", + wrap_opt_if_condition, + ) + return f"{wrap_opt_if_condition}" + + # Figure out the offset of the edge that uses this variable + derivative_var_name = derivative.var_names[0] + for edge_off, a in enumerate(args_with_derivatives): + if a.name == derivative_var_name: + break + else: + raise AssertionError + return f"grad_fn->should_compute_output({edge_off})" + + if is_inplace_foreach: + save_input_stmts = save_variables(info.all_saved_inputs, False, guard_for) + if save_input_stmts: + setup.append( + LOOP_OVER_VECTOR_OF_GRAD_FNS.substitute( + preamble="", statements=save_input_stmts + ) + ) + else: + setup.extend(save_variables(info.all_saved_inputs, False, guard_for)) + for arg in args_with_derivatives: + if is_tensor_list_type(arg.type): + setup.append(f"grad_fn->{arg.name}_size_ = {arg.name}.size();") + return setup + + def setup_derivative(differentiable_inputs: list[DifferentiableInput]) -> list[str]: + body: list[str] = [] + if is_out_fn: + # For out functions, ensure that no input or output requires grad + body.append(DECLARE_GRAD_FN.substitute(op="Node")) + body.append( + SETUP_NONE_REQUIRES_GRAD.substitute( + base_name=base_name, + args_to_check=[arg.name for arg in differentiable_inputs], + ) + ) + body.append( + SETUP_NONE_REQUIRES_GRAD.substitute( + base_name=base_name, + args_to_check=[arg.name for arg in differentiable_outputs], + ) + ) + return body + + op = info.op if info is not None and info.has_derivatives else "NotImplemented" + setup = [] + if not is_inplace_foreach: + setup.extend( + ASSIGN_GRAD_FN.substitute( + op=op, + op_ctor="" + if info is not None and info.has_derivatives + else f'"{cpp.name(f.func)}"', + args_with_derivatives=[arg.name for arg in args_with_derivatives], + ).split("\n") + ) + else: + # note(crcrpar): Assuming in-place foreach function's self_arg is always TensorList. + list_like_arg = "self" + args = [arg.name for arg in args_with_derivatives] + for i, arg in enumerate(args): + if is_inplace_foreach and info is not None: + if arg in refargname2inplace_foreacharg: + foreach_arg = refargname2inplace_foreacharg[arg] + args[i] = foreach_arg.name + ( + "[i]" if isinstance(foreach_arg.type, ListType) else "" + ) + else: + if arg == list_like_arg: + args[i] = arg + "[i]" + setup.extend( + ASSIGN_VECTOR_OF_GRAD_FN.substitute( + op=op, + op_ctor="" + if info is not None and info.has_derivatives + else f'"{cpp.name(f.func)}"', + args_with_derivatives=args, + irange=f"{list_like_arg}.size()", + ).split("\n") + ) + setup.extend(emit_save_inputs()) + + body.extend( + emit_check_no_requires_grad(differentiable_inputs, args_with_derivatives) + ) + declare_grad_fn_template = ( + DECLARE_GRAD_FN if not is_inplace_foreach else DECLARE_VECTOR_OF_GRAD_FN + ) + body.append(declare_grad_fn_template.substitute(op=op)) + body.append(SETUP_DERIVATIVE.substitute(setup=setup)) + return body + + def emit_check_if_in_complex_autograd_allowlist() -> list[str]: + body: list[str] = [] + if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX: + return body + for arg in differentiable_outputs: + name = arg.name + # TODO: should be `arg.type.is_tensor_like()`? + if arg.cpp_type == "at::Tensor" or arg.cpp_type in TENSOR_LIST_LIKE_CTYPES: + body.append(f'throw_error_for_complex_autograd({name}, "{base_name}");') + return body + + def emit_check_no_requires_grad( + tensor_args: list[DifferentiableInput], + args_with_derivatives: list[DifferentiableInput], + ) -> list[str]: + """Checks that arguments without derivatives don't require grad""" + body: list[str] = [] + for arg in tensor_args: + if arg in args_with_derivatives: + continue + arg_name = arg.name + if info and arg_name in info.non_differentiable_arg_names: + continue + if arg_name == "output": + # Double-backwards definitions sometimes take in 'input' and + # 'output', but only define the derivative for input. + continue + body.append(f'check_no_requires_grad({arg_name}, "{arg_name}", "{name}");') + return body + + def emit_original_self_definition() -> list[str]: + body: list[str] = [] + if inplace: + if is_inplace_foreach: + body.append( + "std::vector<::std::optional> original_selfs(self.size());" + ) + else: + body.append("::std::optional original_self;") + + all_forward_grad_cond = [] + for derivative in fw_derivatives: + if derivative.required_original_self_value: + all_forward_grad_cond.append( + get_any_has_forward_grad_name(derivative.var_names) + ) + + if all_forward_grad_cond: + if not is_inplace_foreach: + body.append(f"if ({' || '.join(all_forward_grad_cond)}) {{") + body.append(" original_self = self.clone();") + body.append("}") + else: + current_all_forward_grad_cond = [ + f"{cond}[i]" for cond in all_forward_grad_cond + ] + body.append("for (const auto& i : c10::irange(self.size())) {") + body.append( + f" if ({' || '.join(current_all_forward_grad_cond)}) {{" + ) + body.append(" original_selfs[i] = self[i].clone();") + body.append(" }") + body.append("}") + + return body + + def save_variables( + saved_variables: Sequence[SavedAttribute], + is_output: bool, + guard_for: Callable[[SavedAttribute], str | None] = lambda name: None, + ) -> Sequence[str]: + # assign the saved variables to the generated grad_fn + stmts: list[str] = [] + for arg in sorted(saved_variables, key=lambda sa: str(sa.nctype.name)): + name = ( + arg.nctype.name.name + if isinstance(arg.nctype.name, SpecialArgName) + else arg.nctype.name + ) + foreacharg: Argument | None = None + is_foreacharg_list_type: bool = False + type = arg.nctype.type + expr = arg.expr + stmts_prepend = None + if is_inplace_foreach and info is not None: + # todo(crcrpar): See if we can add some check e.g. `assert foreacharg is not None`. + # for now the example assert would fail. + name_to_query = name.split("_scalar_type")[0] + if name_to_query in refargname2inplace_foreacharg: + foreacharg = refargname2inplace_foreacharg[name_to_query] + is_foreacharg_list_type = isinstance(foreacharg.type, ListType) + if foreacharg is not None: + name_in_expr = ( + f"{foreacharg.name}{'[i]' if is_foreacharg_list_type else ''}" + ) + src_name = name + if "_scalar_type" in src_name: + split_src_name = src_name.split("_scalar_type") + assert len(split_src_name) == 2 + src_name = split_src_name[0] + expr = expr.replace(src_name, name_in_expr) + if ( + type == BaseCType(tensorT) + or type == OptionalCType(BaseCType(tensorT)) + or type == MutRefCType(OptionalCType(BaseCType(tensorT))) + or (is_output and type == BaseCType(scalarT)) + ): + # note(crcrpar): Here `expr` is generated from scratch, `arg.expr` is ignored. + var = name + name += "_" + if var == "self" and inplace: + original_self_var = ( + "original_self" + if not is_inplace_foreach + else "original_selfs[i]" + ) + self_var = var if not is_inplace_foreach else var + "[i]" + stmts_prepend = f"if (!{original_self_var}.has_value()) {original_self_var} = {self_var}.clone()" + var = f"{original_self_var}.value()" + assert not is_output + if inplace and is_output: + assert name == "result_" + var = ( + "self[i]" + if is_inplace_foreach or is_foreacharg_list_type + else "self" + ) + is_inplace_view = f"{var}.is_view()" + expr = f"SavedVariable({var}, {str(is_output).lower()}, {is_inplace_view})" + else: + expr = f"SavedVariable({var}, {str(is_output).lower()})" + if foreacharg is not None and "original_selfs" not in expr: + # pyrefly: ignore [unbound-name] + expr = expr.replace(src_name, name_in_expr) + elif ( + type == BaseCType(tensorListT) + or type == ListCType(OptionalCType(BaseCType(tensorT))) + or type == BaseCType(iTensorListRefT) + or type == VectorCType(BaseCType(tensorT)) + ): + # See Note [nuanced return type of out-of-place foreach functions] + if type == VectorCType(BaseCType(tensorT)): + assert is_foreach and is_output + expr = f"make_saved_variable_list({name}, {str(is_foreach and is_output).lower()})" + name += "_" + elif type == BaseCType(intArrayRefT): + expr = expr + ".vec()" + elif type == BaseCType(symIntArrayRefT): + expr = expr + ".vec()" + elif type == BaseCType(stringT): + expr = f"std::string({expr})" + elif type == OptionalCType(BaseCType(stringT)): + expr = f"{expr}.has_value() ? ::std::optional(std::string({expr}.value())) : ::std::nullopt" + elif type == ArrayRefCType( + elem=BaseCType(type=BaseCppType(ns="at", name="Scalar")) + ): + expr = expr + ".vec()" + + guard = guard_for(arg) + if guard is None: + if stmts_prepend: + stmts.append(f"{stmts_prepend};") + stmts.append(f"grad_fn->{name} = {expr};") + else: + stmts.append(f"if ({guard}) {{") + if stmts_prepend: + stmts.append(f" {stmts_prepend};") + stmts.append(f" grad_fn->{name} = {expr};") + stmts.append("}") + return stmts + + # Generates a Dispatcher::redispatch() call into the dispatcher. We do this mainly for performance reasons: + # - Pre-compute the full DispatchKeySet. This saves the dispatcher from having to read from TLS. + # - redispatch() avoids a redundant call to RecordFunction, which was already called right before + # we entered this autograd kernel. + def emit_dispatch_call( + f: NativeFunction, input_base: str, unpacked_args: Sequence[str] + ) -> str: + """Dispatch call via function in a namespace or method on Tensor.""" + # code-generated autograd kernels plumb and recompute dispatch keys directly through the kernel for performance. + # Ops also always have a function variant of the redispatch API. + # See Note [Plumbing Keys Through The Dispatcher] for details. + dispatch_key_set = "ks & c10::after_autograd_keyset" + call = CALL_REDISPATCH.substitute( + api_name=cpp.name( + f.func, + faithful_name_for_out_overloads=True, + symint_overload=f.func.has_symint(), + ), + unpacked_args=[dispatch_key_set] + list(unpacked_args), + ) + return call + + def wrap_output( + f: NativeFunction, unpacked_bindings: list[Binding], var: str + ) -> str: + call = "" + rhs_value: str | None = None + if not any(r.type.is_tensor_like() for r in f.func.returns): + rhs_value = var + else: + rhs_value = f"std::move({var})" + assert rhs_value is not None + call += ASSIGN_RETURN_VALUE.substitute( + return_values=tie_return_values(f), rhs_value=rhs_value + ) + return call + + def check_tensorimpl_and_storage( + call: str, unpacked_bindings: list[Binding] + ) -> str: + # See NOTE [ TensorImpl and Storage Pointer Sanity Checks ] + stmts_before_call: list[str] = [] + stmts_after_call: list[str] = [] + + if cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE: + return call + + # Check properties of inputs (enforce (1)) + for unpacked_binding in unpacked_bindings: + arg = unpacked_binding.name + noref_cpp_type = unpacked_binding.nctype.type.remove_const_ref() + if noref_cpp_type == BaseCType(tensorListT) or noref_cpp_type == BaseCType( + iTensorListRefT + ): + stmts_before_call += [ + SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg), + SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg), + ] + stmts_after_call += [ + ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg), + ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg), + ] + elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))): + stmts_before_call += [ + SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg), + SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg), + ] + stmts_after_call += [ + ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute( + tensorlist_name=arg + ), + ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute( + tensorlist_name=arg + ), + ] + elif noref_cpp_type == BaseCType(tensorT): + stmts_before_call += [ + SAVE_TENSOR_STORAGE.substitute(tensor_name=arg), + SAVE_TENSOR_IMPL.substitute(tensor_name=arg), + ] + stmts_after_call += [ + ENFORCE_SAME_TENSOR_STORAGE.substitute( + tensor_name=arg, out_tensor_name=arg + ), + ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg), + ] + + assert (stmts_before_call and stmts_after_call) or ( + not stmts_before_call and not stmts_after_call + ) + + # Check properties of outputs (enforce (2), (3)) + if f.func.kind() not in (SchemaKind.inplace, SchemaKind.out): + base_name = f.func.name.name.base # TODO: should be str(f.func.name.name)? + aliased_arg_name = ALL_VIEW_FUNCTIONS.get(base_name, None) + if aliased_arg_name is not None: + aliased_arg_name = unpacked_name(aliased_arg_name) + for i, (ret, ret_name) in enumerate( + zip(f.func.returns, cpp.return_names(f)) + ): + noref_cpp_type = cpp.return_type(ret, symint=True).remove_const_ref() + if noref_cpp_type == BaseCType(tensorT): + if aliased_arg_name is not None: + assert i == 0, ( + "Expect non-CompositeImplicitAutograd view function {base} to return single output" + ) + stmts_after_call += [ + ENFORCE_SAME_TENSOR_STORAGE.substitute( + tensor_name=aliased_arg_name, out_tensor_name=ret_name + ) + ] + else: + if ( + type_wrapper_name(f) + not in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT + ): + stmts_after_call += [ + ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE.substitute( + tensor_name=ret_name, fn_name=type_wrapper_name(f) + ) + ] + + if type_wrapper_name(f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT: + stmts_after_call += [ + ENFORCE_TENSOR_IMPL_USE_COUNT.substitute( + tensor_name=ret_name, fn_name=type_wrapper_name(f) + ) + ] + + # Currently we don't have any functions that return the following types, but + # we should update the checks once we do + elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))): + raise AssertionError( + f"Please add use_count checks for {noref_cpp_type}" + ) + elif noref_cpp_type == BaseCType(tensorListT): + raise AssertionError( + f"Please add use_count checks for {noref_cpp_type}" + ) + + if stmts_before_call and stmts_after_call: + call = ( + RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_before_call) + + call + + RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_after_call) + ) + return call + + def emit_call( + f: NativeFunction, unpacked_bindings: list[Binding], try_jit_decomposition: bool + ) -> str: + # We only care about adding `at::AutoDispatchBelowAutograd` guard for non-variable dispatch + # (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure + # the baseType operations still dispatch to non-Variable type, even if the arguments passed + # in are now Variables. + # See NOTE [ Treating Variables as non-Variables in type dispatch ] for details. + unpacked_args = [b.name for b in unpacked_bindings] + base_type_call = emit_dispatch_call(f, "self_", unpacked_args) + + if get_view_info(f) is not None or modifies_arguments(f): + guard = "at::AutoDispatchBelowAutograd guard;" + else: + guard = "at::AutoDispatchBelowADInplaceOrView guard;" + + any_has_forward_grad = ( + get_any_has_fw_grad_cond(derivative=None) + if requires_derivative + else "false" + ) + return_types = ", ".join( + [cpp.return_type(a, symint=True).cpp_type() for a in f.func.returns] + ) + if len(f.func.returns) > 1: + return_types = f"std::tuple<{return_types}>" + + arg_names = [ + a.name + for a in cpp.arguments( + f.func.arguments, + faithful=True, + symint=True, + method=False, + cpp_no_default_args=set(), + ) + ] + + if not modifies_arguments(f) and not returns_void: + if try_jit_decomposition: + call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES_JVP_DECOMP.substitute( + base_type_call=base_type_call, + tmp_var=TMP_VAR, + guard=guard, + any_has_forward_grad=any_has_forward_grad, + op_name=cpp.name(f.func), + op_overload=f.func.name.overload_name, + return_types=return_types, + arg_names=arg_names, + ) + else: + call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES.substitute( + base_type_call=base_type_call, + tmp_var=TMP_VAR, + guard=guard, + ) + + call += wrap_output(f, unpacked_bindings, TMP_VAR) + else: + assert not try_jit_decomposition + call = DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES.substitute( + base_type_call=base_type_call, guard=guard + ) + call = check_tensorimpl_and_storage(call, unpacked_bindings) + return call + + def emit_history() -> str: + fn = "rebase" if modifies_arguments(f) and view_info is None else "set" + output_names = [r.name for r in differentiable_outputs] + # TODO: flatten allocates a std::vector, which could be expensive + outs = CodeTemplate("flatten_tensor_args( ${outs} )").substitute( + outs=output_names if not is_inplace_foreach else "self" + ) + if not is_inplace_foreach: + return SET_HISTORY.substitute(fn=fn, differentiable_outputs=outs) + else: + return LOOP_OVER_VECTOR_OF_GRAD_FNS.substitute( + preamble=( + f"auto differentiable_outputs = {outs};\n" + f"TORCH_INTERNAL_ASSERT(differentiable_outputs.size() == grad_fns.size());" + ), + statements=f"{fn}_history(differentiable_outputs[i], grad_fns[i]);", + ) + + def emit_save_outputs() -> str: + if is_out_fn: + # out functions don't currently support differentiation + return "" + if info is not None and info.has_derivatives: + stmts = save_variables(info.all_saved_outputs, True) + if len(stmts) == 0: + return "" + if not is_inplace_foreach: + return CONDITIONAL.substitute(cond="grad_fn", statements=stmts) + else: + return LOOP_OVER_VECTOR_OF_GRAD_FNS.substitute( + preamble="", statements=stmts + ) + return "" + + def emit_any_requires_grad() -> list[str]: + extra_condition = "" + if info and info.output_differentiability_conditions: + assert len(info.output_differentiability_conditions) == 1 + extra_condition = f"_any_requires_grad &= ({info.output_differentiability_conditions[0]});" + names_of_args_with_derivatives = [arg.name for arg in args_with_derivatives] + if is_inplace_foreach and info is not None: + for i, arg in enumerate(names_of_args_with_derivatives): + for f_arg, r_arg in inplace_foreacharg2refarg.items(): + if arg == r_arg.name: + names_of_args_with_derivatives[i] = f_arg.name + return [ + SETUP_ANY_REQUIRES_GRAD.substitute( + args_with_derivatives=names_of_args_with_derivatives, + extra_differentiability_conditions=extra_condition, + ) + ] + + def get_any_has_forward_grad_name(var_names: tuple[str, ...]) -> str: + if len(var_names) == 1: + return f"_any_has_forward_grad_{var_names[0]}" + else: + return f"_any_has_forward_grad_{'_'.join(var_names)}" + + def emit_any_has_forward_grad() -> list[str]: + content: list[str] = [] + if not is_foreach: + for derivative in fw_derivatives: + requires_fw_grad = get_any_has_fw_grad_cond(derivative=derivative) + if info and info.output_differentiability_conditions: + assert len(info.output_differentiability_conditions) == 1 + requires_fw_grad = f"({info.output_differentiability_conditions[0]}) && {requires_fw_grad}" + content.append( + f"[[maybe_unused]] auto {get_any_has_forward_grad_name(derivative.var_names)} = {requires_fw_grad};" + ) + else: + for derivative in fw_derivatives: + bool_vector_name = get_any_has_forward_grad_name(derivative.var_names) + cur_derivative_conditions = [] + for inp in differentiable_inputs: + if derivative.required_inputs_fw_grad is None: + continue + if inp.name not in derivative.required_inputs_fw_grad: + continue + inp_name = ( + inp.name + if not inplace + else refargname2inplace_foreacharg[inp.name].name + ) + inp_type = ( + inp.type + if not inplace + else refargname2inplace_foreacharg[inp.name].type + ) + is_list_type = is_tensor_list_type(inp_type) + if is_list_type: + if inp_name != "self": + content.append( + FW_DERIVATIVE_SIZE_CHECK_TEMPLATE.substitute( + inp_name=inp_name + ) + ) + cur_derivative_conditions.append( + # pyrefly: ignore [bad-argument-type] + FW_DERIVATIVE_CHECK_TEMPLATE.substitute( + req_inp=inp_name + "[i]" + ) + ) + else: + cur_derivative_conditions.append( + # pyrefly: ignore [bad-argument-type] + FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp_name) + ) + + content.append(f"std::vector {bool_vector_name}(self.size());") + content.append("for (const auto& i : c10::irange(self.size())) {") + content.append( + f" {bool_vector_name}[i] = {' || '.join(cur_derivative_conditions)};" + ) + content.append("}") + return content + + def emit_check_inplace() -> list[str]: + if not inplace: + return [] + return [ + f"check_inplace({arg.name}, _any_requires_grad);" + for arg in differentiable_outputs + ] + + def emit_fw_derivatives() -> list[str]: + content: list[str] = [] + fw_grad_setters: list[str] = [] + for derivative in fw_derivatives: + res = derivative.var_names + if f.func.name.name.inplace: + assert len(res) == 1, ( + "Expected number of outputs to be 1 if function is inplace" + ) + # TODO update this when inplace namings are unified + res = ("self",) + + assert derivative.required_inputs_fw_grad is not None + + unpacked_arguments = "" + for inp in differentiable_inputs: + inp_name = inp.name + is_input_tensorlist = is_foreach and is_tensor_list_type( + inp.type + if not inplace + else refargname2inplace_foreacharg[inp.name].type + ) + input_suffix = "[i]" if is_input_tensorlist else "" + if is_inplace_foreach: + if inp.name in refargname2inplace_foreacharg: + inp_name = refargname2inplace_foreacharg[inp.name].name + zeros_fn = ( + "zeros_symint" + if inplace and inp.name == "self" + else "_efficientzerotensor_symint" + ) + if inp.name in derivative.required_inputs_fw_grad: + unpacked_arguments += ( + FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute( + inp_name=inp.name, + inp=inp_name + input_suffix, + zeros_fn=zeros_fn, + ) + ) + if zeros_fn == "_efficientzerotensor_symint": + unpacked_arguments += ( + FW_DERIVATIVE_UPDATE_WRAPPED_NUM_TEMPLATE.substitute( + inp_name=inp.name + ) + ) + + if inp.name in (derivative.required_inputs_primal or []): + unpacked_arguments += ( + FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute( + inp_name=inp.name, + inp=inp_name + input_suffix, + ) + ) + if derivative.required_original_self_value: + input_suffix = "s[i]" if is_inplace_foreach else "" + unpacked_arguments += FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute( + inp_name="original_self", + inp="original_self" + input_suffix, + # pyrefly: ignore [unbound-name] + zeros_fn=zeros_fn, + ) + unpacked_arguments += FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute( + inp_name="original_self", + inp="original_self" + input_suffix, + ) + elif inplace and derivative.is_reusing_outplace_formula: + # The gradient wasn't already cloned, do it if grad mode is enabled + unpacked_arguments += ( + "self_t = GradMode::is_enabled() ? self_t.clone() : self_t;" + ) + + if inplace: + is_inplace_str = "true" + else: + is_inplace_str = "false" + + requires_fw_grad = get_any_has_forward_grad_name(derivative.var_names) + + if all( + (isinstance(var_type, BaseType) and var_type.is_tensor_like()) + for var_type in derivative.var_types + ): + # Is there a way to get from BaseType to BaseCType + if len(derivative.var_types) == 1: + opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type() + if not is_foreach: + fw_grad_setters.append( + FW_DERIVATIVE_SETTER_TENSOR.substitute( + out_arg=res[0], is_inplace=is_inplace_str + ) + ) + else: + assert res[0] == ("result" if not inplace else "self") + fw_grad_setters.append( + FW_DERIVATIVE_SETTER_TENSOR_FOREACH.substitute( + out_arg=res[0], is_inplace=is_inplace_str + ) + ) + requires_fw_grad += f" && ({derivative.var_names[0]}.defined())" + else: + tuple_type = TupleCType( + [BaseCType(tensorT)] * len(derivative.var_types) + ) + opt_res_grad_type = OptionalCType(tuple_type).cpp_type() + for idx, single_res in enumerate(res): + fw_grad_setters.append( + FW_DERIVATIVE_SETTER_MULTI_OUTPUT.substitute( + idx=idx, all_res="_".join(res), out_arg=single_res + ) + ) + elif ( + isinstance(derivative.var_types[0], ListType) + and derivative.var_types[0].is_tensor_like() + ): + assert len(derivative.var_types) == 1, ( + "Expected number of outputs to be 1 if function returns ListType" + ) + if not is_foreach: + opt_res_grad_type = OptionalCType( + VectorCType(BaseCType(tensorT)) + ).cpp_type() + fw_grad_setters.append( + FW_DERIVATIVE_SETTER_TENSOR_LIST.substitute( + out_arg=res[0], is_inplace=is_inplace_str + ) + ) + else: + # TODO(crcrpar): Should this (= the foreach specific logic) be refactored somehow? + # Only out-place foreach functions that have entries in `tools/autograd/derivatives.yaml` + # can reach here. + opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type() + fw_grad_setters.append( + FW_DERIVATIVE_SETTER_TENSOR_FOREACH.substitute( + out_arg=res[0], is_inplace=is_inplace_str + ) + ) + else: + raise RuntimeError("Unsupported output type for forward derivative") + + if not is_foreach: + fw_grad_opt_definition = f"{opt_res_grad_type} {'_'.join(res)}_new_fw_grad_opt = ::std::nullopt;" + # View ops create fw_grad that already is a view of the base's fw_grad so just use that + content.append( + FW_DERIVATIVE_TEMPLATE.substitute( + fw_grad_opt_definition=fw_grad_opt_definition, + requires_fw_grad=requires_fw_grad, + formula=derivative.formula, + out_arg="_".join(res), + unpacked_arguments=unpacked_arguments, + ) + ) + else: + # note(crcrpar): Assuming `self` is TensorList. + fw_grad_opt_definition = ( + f"std::vector<{opt_res_grad_type}> {'_'.join(res)}_new_fw_grad_opts" + "(self.size(), ::std::nullopt);" + ) + foreach_forward_grad_formula = derivative.formula + _foreach_arg: Argument | DifferentiableInput + if inplace: + for _foreach_arg, _ref_arg in inplace_foreacharg2refarg.items(): + # note(crcrpar): Massage only Scalar and ArrayRef here. + if not ( + is_tensor_type(_foreach_arg.type) + or is_tensor_list_type(_foreach_arg.type) + ): + pattern = _foreach_arg.name + if isinstance(_foreach_arg.type, ListType): + pattern += "[i]" + foreach_forward_grad_formula = ( + foreach_forward_grad_formula.replace( + _ref_arg.name, pattern + ) + ) + else: + if ( + "result" in foreach_forward_grad_formula + and "result[i]" not in foreach_forward_grad_formula + ): + foreach_forward_grad_formula = ( + foreach_forward_grad_formula.replace("result", "result[i]") + ) + + content.append( + FW_DERIVATIVE_FOREACH_TEMPLATE.substitute( + fw_grad_opt_definition=fw_grad_opt_definition, + vector_of_optional_tensor=f"{'_'.join(res)}_new_fw_grad_opts", + any_has_forward_grad_for_current_index=" || ".join( + get_any_has_forward_grad_name(derivative.var_names) + "[i]" + for derivative in fw_derivatives + ), + formula=foreach_forward_grad_formula, + unpacked_arguments=unpacked_arguments, + ) + ) + + # Set all the grads at the end to avoid: https://github.com/pytorch/pytorch/issues/67367 + content.append("\n".join(fw_grad_setters)) + return content + + def get_any_has_fw_grad_cond(derivative: ForwardDerivative | None) -> str: + # + # Produces a condition string (e.g, "isFwGradDefined(grad_output) || isFwGradDefined(output)") + # + if derivative is None: + # (1) If a derivative is NOT provided, cond will check fw_grad of ALL differentiable inputs + # - Used in the out_fn case when we want to forbid fw derivatives + # - Used in the case where the fw_derivative is not defined, but we want + # To check if there is a decomposition registered for jvp + to_check: list[str] = [] + for inp in list( + mapMaybe( + gen_differentiable_input, + f.func.arguments.non_out + list(f.func.arguments.out), # type: ignore[operator] + ) + ): + if is_tensor_type(inp.type): + to_check.append( + FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name) + ) + elif is_tensor_list_type(inp.type): + to_check.append( + FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE.substitute( + req_inp=inp.name + ) + ) + else: + raise RuntimeError( + f'Unsupported input type for "{name}" when forbidding forward AD usage.' + ) + return f"({' || '.join(to_check)})" + else: + # (2) If derivative is provided, use that information to determine which inputs + # to check fw_grad for + assert derivative.required_inputs_fw_grad is not None + + if len(derivative.required_inputs_fw_grad) == 0: + # Handle functions like stack + # For these, we don't unpack anything and always call the user function + if not ( + len(differentiable_inputs) == 1 + and is_tensor_list_type(differentiable_inputs[0].type) + ): + raise RuntimeError( + f'No differentiable input to "{name}" is a differentiable Tensor (as the provided ' + "forward AD formula does not use any input tangent) even though a forward gradient " + "formula has been defined for it. This case should only happen for function that " + "take a single TensorList as input. All other cases are not supported right now." + ) + any_has_fw_grad = "true" + else: + any_has_fw_grad = " || ".join( + [ + ( + FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE + if is_tensor_list_type(inp.type) + else FW_DERIVATIVE_CHECK_TEMPLATE + ).substitute(req_inp=inp.name) + for inp in differentiable_inputs + if inp.name in derivative.required_inputs_fw_grad + ] + ) + any_has_fw_grad = f"({any_has_fw_grad})" + + return any_has_fw_grad + + def emit_forbid_fw_derivatives(is_out_fn: bool = False) -> str: + if is_out_fn: + msg = "because it is an out= function" + else: + msg = ( + "because it has not been implemented yet.\\nPlease file an issue " + "to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml " + "so that we can prioritize its implementation." + ) + cond = get_any_has_fw_grad_cond(derivative=None) + return ( + FW_DERIVATIVE_FORBID_TEMPLATE.substitute(cond=cond, name=name, msg=msg) + if cond != "" + else "" + ) + + body: list[str] = [] + unpack_args_stats, unpacked_bindings = unpack_args(f) + + body.extend(unpack_args_stats) + if requires_derivative: + body.extend(emit_any_requires_grad()) + body.extend(emit_any_has_forward_grad()) + body.extend(emit_check_inplace()) + body.extend(emit_original_self_definition()) + body.extend(setup_derivative(differentiable_inputs)) + + body.append(emit_call(f, unpacked_bindings, try_jit_decomposition)) + if requires_derivative: + # set_flags has to appear after version_counter, because rebase_history + # requires that the counter is incremented before it is called + body.append(emit_history()) + body.extend(emit_check_if_in_complex_autograd_allowlist()) + + if is_out_fn: + body.append(emit_forbid_fw_derivatives(is_out_fn=True)) + else: + if requires_derivative and not try_jit_decomposition: + if len(fw_derivatives) > 0: + body.extend(emit_fw_derivatives()) + else: + body.append(emit_forbid_fw_derivatives()) + + if requires_derivative: + # Save only after the forward AD has been set up + body.append(emit_save_outputs()) + + if str(f.func.name.name) in RESET_GRAD_ACCUMULATOR: + # `inplace` implies that there is exactly one output named `self`, + # so we can keep the generated code easy. If you need to + # `reset_grad_accumulator` in an operator that's not `inplace`, you can + # remove this assert but the code generation will get more elaborate + assert inplace + body.append("reset_grad_accumulator(self);") + if not returns_void: + body.append(f"return {get_return_value(f)};") + return body diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_view_funcs.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_view_funcs.py new file mode 100644 index 0000000000000000000000000000000000000000..8cc8a2ffcecc4571c5101a265be3a5eeb766473a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/gen_view_funcs.py @@ -0,0 +1,339 @@ +# Generates ViewFuncs.h/cpp +# +# NOTE: If any changes are being made to the ViewFunc codegen please also check +# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp +# The fallback is expected to mimic this codegen, so we should keep the two in sync. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torchgen.api.dispatcher as dispatcher +from torchgen.api.translate import translate +from torchgen.api.types import ( + BaseCType, + Binding, + NamedCType, + SymIntT, + tensorT, + VectorCType, +) +from torchgen.code_template import CodeTemplate +from torchgen.model import Argument, NativeFunction, OptionalType +from torchgen.utils import FileManager + +from .gen_inplace_or_view_type import ( + CALL_DISPATCH, + extract_bindings, + get_view_info, + modifies_arguments, + use_derived, +) + + +if TYPE_CHECKING: + from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo + + +FUNCTION_DECLARATION = CodeTemplate( + """\ +#define ${uppercase_op}_AVAILABLE +struct ${op} : public ${superclass} { + ${op}(${constructor_args}) ${initializer_list} + {} + virtual ~${op}() override = default; + virtual std::vector get_symints() const override; + virtual size_t num_symints() const override; + virtual std::vector get_tensors() const override; + virtual size_t num_tensors() const override; + virtual at::Tensor operator()(const at::Tensor&) const override; + virtual std::unique_ptr clone_and_set( + std::optional> = ::std::nullopt, + std::optional> = ::std::nullopt) const override; + +protected: + virtual void set_symints(std::vector) override; + virtual void set_tensors(std::vector) override; + +private: + ${state} +}; + +""" +) + +FUNCTION_DEFINITION = CodeTemplate( + """\ +std::vector ${op}::get_symints() const { + ${get_symints} +} + +size_t ${op}::num_symints() const { + return static_cast(${num_symints}); +} + +void ${op}::set_symints(std::vector ${symints_vec}) { + TORCH_INTERNAL_ASSERT(${symints_vec}.size() == num_symints()); + ${set_symints} +} + +std::vector ${op}::get_tensors() const { + ${get_tensors} +} + +size_t ${op}::num_tensors() const { + return static_cast(${num_tensors}); +} + +void ${op}::set_tensors(std::vector ${tensors_vec}) { + TORCH_INTERNAL_ASSERT(${tensors_vec}.size() == num_tensors()); + ${set_tensors} +} + +at::Tensor ${op}::operator()(const at::Tensor& ${call_input_name}) const { + return ${op_call}; +} + +std::unique_ptr ${op}::clone_and_set( + std::optional> ${symints_vec}, + std::optional> ${tensors_vec}) const { + auto output = std::make_unique<${op}>(${clone_args}); + if (${symints_vec}.has_value()) { + output->set_symints(std::move(*(${symints_vec}))); + } + if (${tensors_vec}.has_value()) { + output->set_tensors(std::move(*(${tensors_vec}))); + } + return output; +} + +""" +) + + +# e.g. as_strided -> AsStridedViewFunc for camel case or +# as_strided_view_func otherwise +def view_func_name( + f: NativeFunction, include_namespace: bool = False, camel_case: bool = True +) -> str: + name = f.func.name.unambiguous_name() + view_func_name = f"{name.replace('.', '_')}_view_func" + if camel_case: + is_private = view_func_name.startswith("_") + view_func_name = "".join( + [p.title() for p in view_func_name.replace(".", "_").split("_")] + ) + if is_private: + # put the leading underscore back in + view_func_name = f"_{view_func_name}" + namespace = "torch::autograd::generated::" if include_namespace else "" + return f"{namespace}{view_func_name}" + + +def is_symint_or_tensor(arg: Argument) -> bool: + return arg.type.is_tensor_like() or arg.type.is_symint_like() + + +def remove_const_ref(binding: Binding) -> Binding: + return Binding( + name=binding.name, + nctype=binding.nctype.remove_const_ref(), + argument=binding.argument, + default=binding.default, + ) + + +def returns_multi_tensor(fn: NativeFunction) -> bool: + returns = fn.func.returns + assert len(returns) == 1 + returns_list_like = returns[0].type.is_list_like() is not None + returns_tensor_like = returns[0].type.is_tensor_like() + return returns_list_like and returns_tensor_like + + +# Generates strings with logic for getting / setting state of a particular type. +# +# Args: +# bindings (list): List of state bindings of interest (may be empty) +# state_vec_type (NamedCType): Type of vector to either return or copy from +# +# Returns: +# tuple: (list of getter logic strings, list of setter logic strings, string +# with num items expression) +def generate_state_getter_setter( + bindings: list[Binding], + state_vec_type: NamedCType, +) -> tuple[list[str], list[str], str]: + getter_logic = [] + setter_logic = [] + + state_vec = state_vec_type.name + getter_logic.append(f"{state_vec_type.cpp_type()} {state_vec};") + if len(bindings) > 0: + setter_logic.append("auto i = 0;") + + num_exprs = [] + for i, b in enumerate(bindings): + assert isinstance(b.argument, Argument) + if b.argument.type.is_list_like(): + # Handle list-likes. + num_expr = f"{b.name}.size()" + num_exprs.append(num_expr) + getter = f"{state_vec}.insert({state_vec}.end(), {b.name}.begin(), {b.name}.end());" + setter = f"std::copy({state_vec}.begin() + i, {state_vec}.begin() + i + {b.name}.size(), {b.name}.begin());" + elif isinstance(b.argument.type, OptionalType): + # Handle optionals. + num_expr = f"({b.name}.has_value() ? 1 : 0)" + num_exprs.append(num_expr) + conditional = f"if({b.name}.has_value())" + getter = ( + f"{conditional} {state_vec}.insert({state_vec}.end(), *({b.name}));" + ) + setter = f"{conditional} {b.name} = {state_vec}[i];" + else: + num_expr = "1" + num_exprs.append(num_expr) + getter = f"{state_vec}.push_back({b.name});" + setter = f"{b.name} = {state_vec}[i];" + + getter_logic.append(getter) + setter_logic.append(setter) + if i < len(bindings) - 1: + setter_logic.append(f"i += {num_expr};") + + # Reserve / assert based on the total number of items expression. + num_items = "0" if len(num_exprs) == 0 else " + ".join(num_exprs) + if len(bindings) > 0: + getter_logic.insert(1, f"{state_vec}.reserve({num_items});") + + getter_logic.append(f"return {state_vec};") + + return getter_logic, setter_logic, num_items + + +def process_function(fn: NativeFunction, template: CodeTemplate) -> str: + bindings = extract_bindings(fn) + non_self_bindings = [b for b in bindings if b.name != "self"] + + non_self_args = fn.func.arguments.flat_all[1:] + non_self_value_bindings = [ + dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args + ] + + # Generate constructor / clone args for the generated struct. + constructor_args = [b.defn() for b in non_self_bindings] + clone_args = [b.name for b in non_self_bindings] + + # Generate state variable declarations for the generated struct. + state_variables = [ + f"{remove_const_ref(b).defn()};" for b in non_self_value_bindings + ] + + # Generate initializer list expressions for the generated struct. + # allow_expensive_conversions=True because we need to store e.g. SymIntArrayRefs as + # vectors. + init_exprs = translate( + non_self_bindings, non_self_value_bindings, allow_expensive_conversions=True + ) + initializers = [] + for b, init_expr in zip(non_self_bindings, init_exprs): + name = b.nctype.name + assert isinstance(name, str) + initializers.append(f"{name}({init_expr.expr})") + + # Generate call to underlying view op + call_input_name = "input_base" + op_call_args = [call_input_name, *(b.name for b in non_self_bindings)] + op_call = CALL_DISPATCH.substitute( + unambiguous_name=fn.func.name.unambiguous_name(), + unpacked_args=op_call_args, + ) + + # Multi-output views additionally require a view_idx for disambiguation. + if returns_multi_tensor(fn): + view_idx_name = "view_idx" + view_idx_typename = "int64_t" + view_idx_decl = f"{view_idx_typename} {view_idx_name}" + constructor_args.append(view_idx_decl) + clone_args.append(view_idx_name) + state_variables.append(f"{view_idx_decl};") + initializers.append(f"{view_idx_name}({view_idx_name})") + op_call += f"[{view_idx_name}]" + + # Generate initializer list for the generated struct. + initializer_list = f": {', '.join(initializers)}" if len(initializers) > 0 else "" + + # Generate getter / setter logic for any symints. + symint_bindings = [ + b + for b in non_self_bindings + if isinstance(b.argument, Argument) and b.argument.type.is_symint_like() + ] + symints_vec_type = NamedCType("symints", VectorCType(BaseCType(SymIntT))) + get_symints, set_symints, num_symints = generate_state_getter_setter( + symint_bindings, symints_vec_type + ) + + # Generate getter / setter logic for any tensors. + tensor_bindings = [ + b + for b in non_self_bindings + if isinstance(b.argument, Argument) and b.argument.type.is_tensor_like() + ] + tensors_vec_type = NamedCType("tensors", VectorCType(BaseCType(tensorT))) + get_tensors, set_tensors, num_tensors = generate_state_getter_setter( + tensor_bindings, tensors_vec_type + ) + + return template.substitute( + op=view_func_name(fn), + uppercase_op=view_func_name(fn, camel_case=False).upper(), + superclass="torch::autograd::ViewFunc", + initializer_list=initializer_list, + state=state_variables, + constructor_args=constructor_args, + clone_args=clone_args, + symints_vec=symints_vec_type.name, + get_symints=get_symints, + set_symints=set_symints, + num_symints=num_symints, + tensors_vec=tensors_vec_type.name, + get_tensors=get_tensors, + set_tensors=set_tensors, + num_tensors=num_tensors, + call_input_name=call_input_name, + op_call=op_call, + ) + + +def gen_view_funcs( + out: str, + fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo], + template_path: str, +) -> None: + # don't need the info parts, just the function + fns = [fn.func for fn in fns_with_infos if use_derived(fn)] + # only want out-of-place views + view_fns = [ + fn for fn in fns if get_view_info(fn) is not None and not modifies_arguments(fn) + ] + + declarations = [process_function(fn, FUNCTION_DECLARATION) for fn in view_fns] + definitions = [process_function(fn, FUNCTION_DEFINITION) for fn in view_fns] + ops_headers = [f"#include " for fn in view_fns] + + file_basename = "ViewFuncs" + fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) + for suffix in [".h", ".cpp"]: + fname = file_basename + suffix + fm.write_with_template( + fname, + fname, + lambda: { + "generated_comment": "@" + + f"generated from {fm.template_dir_for_comments()}/{fname}", + "view_func_declarations": declarations, + "view_func_definitions": definitions, + "ops_headers": ops_headers, + }, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/load_derivatives.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/load_derivatives.py new file mode 100644 index 0000000000000000000000000000000000000000..59669b42cd5d45643306f6fd83bf3adb73b6c288 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/packaged/autograd/load_derivatives.py @@ -0,0 +1,1025 @@ +# Parses derivatives.yaml into autograd functions +# +# Each autograd function is represented by `DifferentiabilityInfo` containing +# a list of `Derivative`. See `torchgen.api.autograd` for the data models. + +from __future__ import annotations + +import re +from collections import Counter, defaultdict +from typing import Any, TYPE_CHECKING + +import yaml + +from torchgen.api import cpp +from torchgen.api.autograd import ( + Derivative, + DifferentiabilityInfo, + ForwardDerivative, + SavedAttribute, +) +from torchgen.api.types import ( + BaseCType, + Binding, + boolT, + CppSignatureGroup, + layoutT, + longT, + NamedCType, + OptionalCType, + scalarTypeT, + SpecialArgName, + stringT, + symIntArrayRefT, + SymIntT, + tensorGeometryT, + tensorOptionsT, + typeAndSizeT, + VectorCType, +) +from torchgen.context import with_native_function +from torchgen.gen import get_grouped_by_view_native_functions, parse_native_yaml +from torchgen.model import ( + AUTOGRAD_KEYS, + FunctionSchema, + NativeFunction, + NativeFunctionsViewGroup, + OperatorName, + SchemaKind, + Type, + Variant, +) +from torchgen.utils import concatMap, IDENT_REGEX, split_name_params +from torchgen.yaml_utils import YamlLoader + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +DerivativeRet = tuple[dict[FunctionSchema, dict[str, DifferentiabilityInfo]], set[str]] + +_GLOBAL_LOAD_DERIVATIVE_CACHE: dict[tuple[str, str], DerivativeRet] = {} + +_VALID_AUTOGRAD_KEYS = set(AUTOGRAD_KEYS) + + +# This function directly adds per-dispatchkey derivative entries for {view}_copy variants of each view op. +# Since every {view} and {view}_copy op shares the same derivative formula, +# we generate them here instead of duplicating them in the yaml. +# See Note [Codegen'd {view}_copy Operators] +def add_view_copy_derivatives( + infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]], + view_groups: list[NativeFunctionsViewGroup], +) -> None: + # Get the map from each view op's name to its corresponding view group + view_name_to_group: dict[OperatorName, NativeFunctionsViewGroup] = { + g.view.func.name: g for g in view_groups + } + + view_infos = {} + + for info_dispatch_dict in infos.values(): + # maybe_view_group only needs to be calculated once per info_dispatch_dict + maybe_view_group = None + view_copy_differentiability_infos = {} + for dispatch_key, info in info_dispatch_dict.items(): + maybe_view_group = view_name_to_group.get(info.func.func.name, None) + if maybe_view_group is not None and maybe_view_group.view_copy is not None: + view_copy_info = info.create_view_copy_from_view_derivative( + maybe_view_group + ) + if view_copy_info is not None: + fn_schema = view_copy_info.func.func + view_copy_differentiability_infos[dispatch_key] = view_copy_info + else: + break + # prefer manually-defined derivatives if any + # pyrefly: ignore [unbound-name] + if len(view_copy_differentiability_infos) > 0 and fn_schema not in infos: + # pyrefly: ignore [unbound-name] + assert fn_schema is not None + # pyrefly: ignore [unbound-name] + view_infos[fn_schema] = view_copy_differentiability_infos + + infos.update(view_infos) + + +def load_derivatives( + derivatives_yaml_path: str, native_yaml_path: str, tags_yaml_path: str +) -> DerivativeRet: + # Do some caching as this is a deterministic function + global _GLOBAL_LOAD_DERIVATIVE_CACHE + key = (derivatives_yaml_path, native_yaml_path) + if key not in _GLOBAL_LOAD_DERIVATIVE_CACHE: + with open(derivatives_yaml_path) as f: + definitions = yaml.load(f, Loader=YamlLoader) + + funcs = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions + # From the parsed native functions, separate out the (generated) view_copy functions, + # so we can generate derivatives for them separately. + native_functions_with_view_groups = get_grouped_by_view_native_functions(funcs) + native_functions = concatMap( + lambda g: [g] + if isinstance(g, NativeFunction) + else list(g.functions(include_copy=True)), + native_functions_with_view_groups, + ) + view_groups = [ + g + for g in native_functions_with_view_groups + if isinstance(g, NativeFunctionsViewGroup) + ] + + # What's the difference between function schema v.s. signature? + # function schema is the complete declaration including mutability annotation / default value and etc. + # signature is the canonical schema for a group of functions (in-place/out/functional variants) + # that are semantically related. + functions_by_signature: dict[FunctionSchema, list[NativeFunction]] = ( + defaultdict(list) + ) + functions_by_schema: dict[str, NativeFunction] = {} + for function in native_functions: + functions_by_signature[function.func.signature()].append(function) + assert str(function.func) not in functions_by_schema + functions_by_schema[str(function.func)] = function + + # Keep track of how many of which ops we've seen so we can + # disambiguate them with a numeric suffix. + op_counter = Counter[str]() + + # infos is a dict that maps FunctionSchema -> a dict of per dispatch key DifferentiabilityInfos + # this is useful because in tools/autograd/gen_autograd.py:match_differentiability_info + # we ultimately need to categorize the DifferentiabilityInfos by FunctionSchema + infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]] = {} + used_dispatch_keys: set[str] = set() + for defn_dict in definitions: + # Ensure that the old derivatives.yaml schema with no dispatch key can be loaded. + if "dispatch" not in defn_dict: + specification = defn_dict.pop("name") + output_differentiability = defn_dict.pop( + "output_differentiability", None + ) + defn_dict = {"name": specification, "dispatch": {"Default": defn_dict}} + if output_differentiability: + defn_dict["output_differentiability"] = output_differentiability + name, per_dispatch_diffinfos = create_differentiability_info( + defn_dict, + functions_by_signature, + functions_by_schema, + op_counter, + used_dispatch_keys, + ) + infos[name] = per_dispatch_diffinfos + + add_view_copy_derivatives(infos, view_groups) + + # cache both loaded infos as well a a set of all the dispatch_keys/aliases + # that appear in derivatives.yaml. used_dispatch_keys is useful for generating + # VariableType.cpp where we need a TORCH_LIBRARY_IMPL for every autograd dispatch key used + _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos, used_dispatch_keys + + return _GLOBAL_LOAD_DERIVATIVE_CACHE[key] + + +# TODO: Why is this going through CppSignatureGroup, that doesn't make sense... +@with_native_function +def cpp_arguments(f: NativeFunction) -> Sequence[Binding]: + sigs = CppSignatureGroup.from_native_function(f, method=False) + if sigs.symint_signature is not None: + return sigs.symint_signature.arguments() + else: + return sigs.signature.arguments() + + +def create_derivative( + f: NativeFunction, + formula: str, + var_names: tuple[str, ...], + available_named_gradients: Sequence[str], +) -> Derivative: + original_formula = formula + arguments: list[NamedCType] = [ + a.nctype.remove_const_ref() for a in cpp_arguments(f) + ] + + return_names = tuple(n if n != "self" else "result" for n in cpp.return_names(f)) + return_types = tuple( + cpp.return_type(r, symint=True).remove_const_ref() for r in f.func.returns + ) + + named_returns = [ + NamedCType(name, type) for name, type in zip(return_names, return_types) + ] + + formula, saved_inputs = saved_variables(formula, arguments, var_names) + formula, saved_outputs = saved_variables(formula, named_returns, var_names) + + used_named_gradients = { + name + for name in available_named_gradients + if re.search(IDENT_REGEX.format(name), formula) + } + + # Check that the referenced derivatives in the formula are in bounds + for i in used_gradient_indices(formula): + if i >= len(f.func.returns): + raise RuntimeError( + f"Out of bounds grads access: derivative formula for {cpp.name(f.func)} " + f"used grads[{i}], but the forward only returns {len(f.func.returns)} outputs." + ) + + return Derivative( + formula=formula, + original_formula=original_formula, + var_names=var_names, + saved_inputs=saved_inputs, + saved_outputs=saved_outputs, + named_gradients=used_named_gradients, + ) + + +def create_forward_derivative( + f: NativeFunction, formula: str, names: tuple[str, ...] +) -> ForwardDerivative: + var_names = names + var_types: tuple[Type, ...] | None = None + for r in f.func.returns: + if r.name in var_names: + if var_types is None: + var_types = () + var_types = var_types + (r.type,) + + # Handle default return names + if var_types is None: + if var_names == ("result",): + assert len(f.func.returns) == 1 + var_types = (f.func.returns[0].type,) + else: + for var_name in var_names: + res = re.findall(r"^result(\d+)$", var_name) + if len(res) == 1: + if var_types is None: + var_types = () + arg_idx = int(res[0]) + var_types = var_types + (f.func.returns[arg_idx].type,) + + assert var_types is not None, "No matching output for forward derivative definition" + return ForwardDerivative( + formula=formula, + var_names=var_names, + var_types=var_types, + required_inputs_fw_grad=None, + required_inputs_primal=None, + required_original_self_value=False, + is_reusing_outplace_formula=False, + ) + + +def postprocess_forward_derivatives( + f: NativeFunction, + defn_name: str, + all_arg_names: list[str], + derivatives: list[Derivative], + forward_derivatives: list[ForwardDerivative], + args_with_derivatives: Sequence[Binding], +) -> list[ForwardDerivative]: + def find_required_inputs(formula: str, postfix: str) -> tuple[str, ...]: + is_foreach = f.func.name.name.base.startswith("_foreach_") + required_inputs = set() + for arg in args_with_derivatives: + if ( + arg.type in ("at::TensorList", "const at::ITensorListRef &") + and not is_foreach + ): + # The functions taking TensorList handle everything internally + continue + arg_name = arg.name + + found = re.search(IDENT_REGEX.format(arg_name), formula) + if found: + raise RuntimeError( + f"The forward formula for {defn_name} is using the base name of the {arg_name} " + f"argument which is ambiguous. You should use {arg_name}_p to access the primal " + f"value and {arg_name}_t to access the tangent." + ) + + found = re.search(IDENT_REGEX.format(arg_name + postfix), formula) + if found: + required_inputs.add(arg_name) + + return tuple(required_inputs) + + updated_derivatives: list[ForwardDerivative] = [] + + for defn in forward_derivatives: + formula = defn.formula + required_inputs_tangent = find_required_inputs(formula, "_t") + if formula == "auto_element_wise": + assert f.func.kind() != SchemaKind.inplace, ( + f"Cannot use auto_element_wise with {f.func.name} because it is an in-place variant" + ) + if ( + (not len(args_with_derivatives) == 1) + or len(forward_derivatives) > 1 + or len(forward_derivatives[0].var_names) > 1 + ): + raise RuntimeError( + f"Derivative definition of {defn_name} in derivatives.yaml defines the " + "forward definition of gradient as element_wise but this only " + "works for functions with a single differentiable input and a " + "single differentiable output." + ) + if not len(derivatives) == 1: + raise RuntimeError( + f"Derivative definition of {defn_name} in derivatives.yaml defines the " + "forward definition of gradient as element_wise but it does not " + "defines the gradient formula for its argument which is required." + ) + # This transformation is based on the observation that for element-wise functions, the Jacobian + # matrix is diagonal and thus doing J * v is the same as (v^T J)^T (in practice, we ignore the transpositions) + # For the complex case, we use hermitian transpose and get (v.conj() J).conj() + # So here we are going to reuse the backward formula and replace two things: + # 1) all occurrences of "grad" with "foo_t.conj()", where foo is the name of the unique differentiable input. + # 2) all usage of an original input "foo" with its primal value "foo_p". + # 3) conjugate the final result + # For example, for abs, the backward formula is: + # grad * self.sgn() + # And this function generates a forward formula that is: + # (self_t.conj() * self_p.sgn()).conj() + + backward_formula = derivatives[0].original_formula + input_name = args_with_derivatives[0].name + + # Do replacement 1) of the grad + def repl(m: Any) -> str: + return f"{m.group(1)}{input_name}_t.conj(){m.group(2)}" + + fw_formula = re.sub(IDENT_REGEX.format("grad"), repl, backward_formula) + + # Do replacement 2) of the input variables + for arg in args_with_derivatives: + arg_name = arg.name + + def repl(m: Any) -> str: + return f"{m.group(1)}{arg_name}_p{m.group(2)}" + + fw_formula = re.sub(IDENT_REGEX.format(arg_name), repl, fw_formula) + + # Do the final conjugate 3) + fw_formula = f"({fw_formula}).conj()" + + # Since there is a single differentiable inputs and we necessarily need its tangent we can + # simply require all differentiable input's tangent. + required_inputs_tangent = tuple(all_arg_names) + formula = fw_formula + elif formula == "auto_linear": + if ( + len(forward_derivatives) > 1 + or len(forward_derivatives[0].var_names) > 1 + ): + raise RuntimeError( + f"Derivative definition of {defn_name} in derivatives.yaml defines the " + "forward definition of gradient as linear but this only works " + "for functions with a single differentiable output." + ) + # This transformation is based on the observation that linear functions can be written as: + # y = f(x) = A * x + # For some matrix A and the Jacobian of the function f is also A. + # So doing J * v = A * v = f(v). + # Hence to do the jvp, we simply need to evaluate the function at the point v instead of x. + # We do this by calling the forward again by replacing any occurrence of the differentiable + # input "foo" by it's tangent "foo_t". + # Note that multiple inputs are not a problem as long as the function is truly linear wrt to + # the vector where all the differentiable inputs are stacked. + + diff_arg_names = [arg.name for arg in args_with_derivatives] + assert len(diff_arg_names) > 0 + + # Do replacement of input variables + new_args = [] + for arg_name in all_arg_names: + if arg_name in diff_arg_names: + arg_name = arg_name + "_t" + # pyrefly: ignore [bad-argument-type] + new_args.append(arg_name) + + # TODO we are trolling + if f.func.has_symint(): + defn_name += "_symint" + + # Call into the forward again. We need two cases here to handle both Tensor methods and at:: functions. + if Variant.function in f.variants: + fw_formula = f"at::{defn_name}({', '.join(new_args)})" + else: + assert Variant.method in f.variants + fw_formula = f"{new_args[0]}.{defn_name}({', '.join(new_args[1:])})" + + # All of the input tangents are always used so all of them are required here. + required_inputs_tangent = tuple(diff_arg_names) + formula = fw_formula + + # At this point, the formula is final and is not modified anymore. + + # During forward formula, we use the primal instead of the input Tensors. + # This call inspects the formula to find for which input's primal are used. + required_inputs_primal = find_required_inputs(formula, "_p") + + updated_derivatives.append( + ForwardDerivative( + formula=formula, + var_names=defn.var_names, + var_types=defn.var_types, + required_inputs_fw_grad=required_inputs_tangent, + required_inputs_primal=required_inputs_primal, + required_original_self_value=False, + is_reusing_outplace_formula=False, + ) + ) + + return updated_derivatives + + +def is_forward_derivative_definition( + all_arg_names: list[str], names: tuple[str, ...] +) -> bool: + for name in names: + return name not in all_arg_names + raise RuntimeError("Expected `names` to be non-empty") + + +def create_differentiability_info( + defn_dict: dict[Any, Any], + functions_by_signature: dict[FunctionSchema, list[NativeFunction]], + functions_by_schema: dict[str, NativeFunction], + op_counter: Counter[str], + used_dispatch_keys: set[str], +) -> tuple[FunctionSchema, dict[str, DifferentiabilityInfo]]: + """Processes a single entry `defn` in derivatives.yaml""" + + def canonical_function( + functions: Sequence[NativeFunction], name: str + ) -> NativeFunction: + for f in functions: + if ( + not f.func.is_functional_fn() + and not f.func.is_out_fn() + and name == str(f.func.name.name) + ): + return f + # some functions only have in-place variants + assert name + "_" == cpp.name(functions[0].func) + return functions[0] + + def split_names(raw_names: str) -> tuple[str, ...]: + """Given "foo, bar", return ["foo", "bar"].""" + return tuple(x.strip() for x in raw_names.split(",")) + + def check_grad_usage(defn_name: str, derivatives: Sequence[Derivative]) -> None: + """ + Check for some subtle mistakes one might make when writing derivatives. + These mistakes will compile, but will be latent until a function is + used with double backwards. + """ + + uses_grad = False # true if any derivative uses "grad" + num_grads_uses = 0 # count of uses of "grads" or "grads[INDEX]" + uses_named_grads = False # true if any derivative uses "grad_{name}" + used_grads_indices: list[int] = [] # which indices of grads are used + for d in derivatives: + formula = d.formula + uses_grad = uses_grad or bool( + re.findall(IDENT_REGEX.format("grad"), formula) + ) + num_grads_uses += len(re.findall(IDENT_REGEX.format("grads"), formula)) + uses_named_grads = uses_named_grads or bool(d.named_gradients) + used_grads_indices.extend(used_gradient_indices(formula)) + # This is a basic sanity check: the number of places we see + # "grads" should be no fewer than the number of indices we see + # inside "grads". They may not be equal because we may use + # "grads" without an index. + assert num_grads_uses >= len(used_grads_indices) + # Thus if the number is equal, every use of grads is also + # indexed. + only_used_grads_indices = num_grads_uses == len(used_grads_indices) + + if uses_grad and num_grads_uses > 0: + raise RuntimeError( + f"Derivative definition of {defn_name} in derivatives.yaml illegally " + "mixes use of 'grad' and 'grads'. Consider replacing " + "occurrences of 'grad' with 'grads[0]'" + ) + + if only_used_grads_indices and set(used_grads_indices) == {0}: + raise RuntimeError( + f"Derivative definition of {defn_name} in derivatives.yaml solely " + "refers to 'grads[0]'. If the first output is indeed the " + "only differentiable output, replace 'grads[0]' with 'grad'; " + "otherwise, there is a likely error in your derivatives " + "declaration." + ) + + if uses_named_grads and (uses_grad or num_grads_uses > 0): + raise RuntimeError( + f"Derivative definition of {defn_name} in derivatives.yaml illegally " + 'mixes use of "grad_RETURN_NAME" and "grad" or "grads[x]". Use ' + "only one method for identifying gradients." + ) + + @with_native_function + def set_up_derivatives( + f: NativeFunction, + ) -> tuple[ + Sequence[Derivative], + Sequence[ForwardDerivative], + Sequence[Binding], + Sequence[str], + Sequence[str], + ]: + # Set up the derivative information + derivatives: list[Derivative] = [] + forward_derivatives: list[ForwardDerivative] = [] + non_differentiable_arg_names: list[str] = [] + args_with_derivatives_set: set[str] = set() + + all_arg_names = [a.name for a in cpp_arguments(f)] + all_ret_names = [ + r.name for r in f.func.returns + ] # only used for the assert below + # output_differentiability is captured from the enclosed + # scope. Don't modify it. + # + # If it is not present, then no output is explicitly + # undifferentiable. + # + # It may be present and shorter than the length of return + # values. If that's the case, any return value that does not + # have a corresponding entry is considered not differentiable. + differentiability = output_differentiability or [True] * len(f.func.returns) + # A return is available as a named gradient ... + available_named_gradients = [ + f"grad_{ret.name}" + for ret, differentiable in zip(f.func.returns, differentiability) + # if it has not been explicitly made undifferentiable + if differentiable + # and if it has a name + and ret.name is not None + # and if its type is differentiable + and ret.type.is_tensor_like() + ] + + for raw_names in sorted(defn.keys()): + formula = defn[raw_names] + names = split_names(raw_names) + + for name in names: + assert not (name in all_arg_names and name in all_ret_names), ( + f"While processing the derivative formula for '{f.func.name}' wrt '{name}', " + f"expected '{name}' to not be both an input arg and named return. " + ) + + if is_forward_derivative_definition(all_arg_names, names): + forward_derivatives.append(create_forward_derivative(f, formula, names)) + else: + if formula.lower().strip() == "non_differentiable": + non_differentiable_arg_names += names + else: + derivative = create_derivative( + f, formula, names, available_named_gradients + ) + derivatives.append(derivative) + args_with_derivatives_set |= set(names) + + overlap = args_with_derivatives_set.intersection(non_differentiable_arg_names) + if overlap: + raise RuntimeError( + f"derivatives definition for {defn} have overlapped non_differentiable " + f"and differentiable variables: {overlap}" + ) + + # Next, let us determine the list of inputs in order. + # TODO: do we need eagerly calculate and save it here? Can it be derived + # from NativeFunction and `derivatives` on callsites instead? + args_with_derivatives = [ + a for a in cpp_arguments(f) if a.name in args_with_derivatives_set + ] + + # Postprocess forward derivatives definitions now that we know the differentiable arguments + forward_derivatives = postprocess_forward_derivatives( + f, + defn_name, + all_arg_names, + derivatives, + forward_derivatives, + args_with_derivatives, + ) + + # Test to see if the use of 'grads' makes sense. + check_grad_usage(defn_name, derivatives) + + return ( + derivatives, + forward_derivatives, + args_with_derivatives, + non_differentiable_arg_names, + available_named_gradients, + ) + + # NB: Removes 'name' from defn dictionary + specification = defn_dict.pop("name") + defn_name, _ = split_name_params(specification) + # NB: Removes 'output_differentiability' from defn dictionary + # `None` means all differentiable. + output_differentiability = defn_dict.pop("output_differentiability", None) + output_differentiability_conditions = None + if output_differentiability and any( + isinstance(diff, str) for diff in output_differentiability + ): + if len(output_differentiability) != 1: + raise RuntimeError( + f"Not supported: for {specification}," + f"output_differentiability must either be " + f"list[bool] or a list[str] where each str is a " + f"condition. In the case where it is a condition, " + f"we only support single-output functions. " + f"Please file us an issue. " + ) + output_differentiability_conditions = output_differentiability + output_differentiability = [True] + + schema_function = functions_by_schema.get(specification) + if not schema_function: + avail = "\n".join( + k for k, v in functions_by_schema.items() if cpp.name(v.func) == defn_name + ) + raise RuntimeError( + f"could not find ATen function for schema: {specification} " + f". Available signatures:\n{avail}" + ) + + # now map this to the legacy schema; this isn't technically necessary, but we'd need some logic here + # to map in-place schemas to the out-of-place variants. + # TODO: maybe the logic to handle the legacy schema is no longer necessary? + signature = schema_function.func.signature() + functions = functions_by_signature[signature] + if len(functions) == 0: + avail = "\n".join( + str(k) + for k, v in functions_by_signature.items() + if cpp.name(k) == defn_name + ) + raise RuntimeError( + f"could not find ATen function for legacy signature: {signature} " + f"corresponding to schema {specification}. Please report a bug to PyTorch. " + f"Available signatures:\n{avail}" + ) + + canonical = canonical_function(functions, defn_name) + if "grad_input_mask" in (a.name for a in cpp_arguments(canonical)): + raise RuntimeError( + f"Schema for {defn_name} has an argument named grad_input_mask, " + "but this name would be shadowed by our codegen. " + "Please use a different name in native_functions.yaml." + ) + + if "result" in (a.name for a in cpp_arguments(canonical)): + raise RuntimeError( + f"Schema for {defn_name} has an argument named result, " + "but this is only allowed for outputs." + "Please use a different name in native_functions.yaml." + ) + + diffinfo_dict = {} + for key, defn in defn_dict["dispatch"].items(): + if key != "Default" and key not in _VALID_AUTOGRAD_KEYS: + raise RuntimeError( + f"Invalid dispatch key {key} in derivatives.yaml for {specification}," + f" expected key to be one of {_VALID_AUTOGRAD_KEYS}" + ) + if key not in used_dispatch_keys: + used_dispatch_keys.add(key) + + ( + derivatives, + forward_derivatives, + args_with_derivatives, + non_differentiable_arg_names, + available_named_gradients, + ) = set_up_derivatives(canonical) + + used_named_gradients: set[str] = set() + for d in derivatives: + used_named_gradients |= d.named_gradients + + # only assign an op name if we are actually going to calculate a derivative + op = None + if args_with_derivatives: + op_prefix = _create_op_prefix(defn_name) + if key != "Default": + op_prefix = op_prefix + key + op = f"{op_prefix}{op_counter[op_prefix]}" + op_counter[op_prefix] += 1 + + diffinfo_dict[key] = DifferentiabilityInfo( + name=defn_name, + func=canonical, + op=op, + derivatives=derivatives, + forward_derivatives=forward_derivatives, + all_saved_inputs=dedup_vars( + [v for d in derivatives for v in d.saved_inputs] + ), + all_saved_outputs=dedup_vars( + [v for d in derivatives for v in d.saved_outputs] + ), + available_named_gradients=available_named_gradients, + used_named_gradients=used_named_gradients, + args_with_derivatives=args_with_derivatives, + non_differentiable_arg_names=non_differentiable_arg_names, + output_differentiability=output_differentiability, + output_differentiability_conditions=output_differentiability_conditions, + ) + + return canonical.func, diffinfo_dict + + +GRAD_INDEX_REGEX = r"(?:^|\W)grads\[(\d+)\]" + + +def used_gradient_indices(formula: str) -> list[int]: + """Determine a list of gradient indices (the i in grads[i]) that + are used by the formula. + + >>> used_gradient_indices("foo(grads[0], grads[1])") + [0, 1] + """ + return [int(i) for i in re.findall(GRAD_INDEX_REGEX, formula)] + + +def saved_variables( + formula: str, + nctypes: list[NamedCType], + var_names: tuple[str, ...], +) -> tuple[str, tuple[SavedAttribute, ...]]: + def stride_expr(name: str) -> str: + assert var_names == (name,), ( + 'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor ' + 'that ".strides()" is being called on.' + ) + return f'strides_or_error({name}, "{name}")' + + REPLACEMENTS: list[tuple[str, dict[str, Any]]] = [ + # replace self.sym_sizes() with self_sym_sizes + ( + r"{}.sym_sizes\(\)", + { + "suffix": "_sym_sizes", + "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)), + }, + ), + # replace self->sym_sizes() with self_sym_sizes_opt + ( + r"{}->sym_sizes\(\)", + { + "suffix": "_sym_sizes_opt", + "nctype": lambda name: NamedCType( + name, OptionalCType(BaseCType(symIntArrayRefT)) + ), + "expr": lambda name: f"{name}.has_value() ? std::optional({name}->sym_sizes()) : std::nullopt", + }, + ), + # replace self.sym_blocksize() with self_sym_blocksize_opt + ( + r"{}.sym_blocksize\(\)", + { + "suffix": "_self_sym_blocksize_opt", + "nctype": lambda name: NamedCType( + name, OptionalCType(BaseCType(symIntArrayRefT)) + ), + "expr": lambda name: f"at::sparse_csr::getSymIntBlockSize({name})", + }, + ), + # replace self.options() with self_options + ( + r"{}.options\(\)", + { + "suffix": "_options", + "nctype": lambda name: NamedCType(name, BaseCType(tensorOptionsT)), + }, + ), + # replace zeros_like(self) with self_info + ( + r"zeros_like\({}\)", + { + "suffix": "_info", + "nctype": lambda name: NamedCType(name, BaseCType(typeAndSizeT)), + "expr": lambda name: name, # at save-time + "res": lambda name: name + "_info.zeros()", # at eval-time + }, + ), + # replace self.sym_size(2) with self_sym_size_2 + ( + r"{}.sym_size\((-?\w+)\)", + { + "suffix": lambda m: f"_sym_argsize_{m.groups()[0].replace('-', 'minus_')}", + "nctype": lambda name: NamedCType(name, BaseCType(SymIntT)), + }, + ), + # replace self.numel() with self_numel + ( + r"{}.numel\(\)", + { + "suffix": "_numel", + "nctype": lambda name: NamedCType(name, BaseCType(longT)), + }, + ), + # replace self.sym_numel() with self_sym_numel + ( + r"{}.sym_numel\(\)", + { + "suffix": "_sym_numel", + "nctype": lambda name: NamedCType(name, BaseCType(SymIntT)), + }, + ), + # replace to_args_sizes(self) with self_args_sizes + ( + r"to_args_sizes\({}\)", + { + "suffix": "_args_sizes", + "nctype": lambda name: NamedCType( + name, VectorCType(VectorCType(BaseCType(longT))) + ), + }, + ), + # replace to_args_sizes_symint(self) with self_args_sizes + ( + r"to_args_sizes_symint\({}\)", + { + "suffix": "_args_sizes_symint", + "nctype": lambda name: NamedCType( + name, VectorCType(VectorCType(BaseCType(SymIntT))) + ), + }, + ), + # replace to_args_scalartypes(self) with self_args_scalartypes + ( + r"to_args_scalartypes\({}\)", + { + "suffix": "_args_scalartypes", + "nctype": lambda name: NamedCType( + name, VectorCType(BaseCType(scalarTypeT)) + ), + }, + ), + # replace TensorGeometry(self) with self_geometry + ( + r"TensorGeometry\({}\)", + { + "suffix": "_geometry", + "nctype": lambda name: NamedCType(name, BaseCType(tensorGeometryT)), + }, + ), + ( + r"{}.scalar_type\(\)", + { + "suffix": "_scalar_type", + "nctype": lambda name: NamedCType(name, BaseCType(scalarTypeT)), + }, + ), + # replace self.dim() with self_dim + ( + r"{}.dim\(\)", + { + "suffix": "_dim", + "nctype": lambda name: NamedCType(name, BaseCType(longT)), + }, + ), + # replace self.sym_strides() with self_sym_strides + ( + r"{}.sym_strides\(\)", + { + "suffix": "_sym_strides", + "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)), + "expr": stride_expr, + }, + ), + # replace self.layout() with self_layout + ( + r"{}.layout\(\)", + { + "suffix": "_layout", + "nctype": lambda name: NamedCType(name, BaseCType(layoutT)), + }, + ), + # replace self.is_conj() with self_conjugate + ( + r"{}.is_conj\(\)", + { + "suffix": "_conjugate", + "nctype": lambda name: NamedCType(name, BaseCType(boolT)), + }, + ), + ] + + # find which arguments need to be saved + saved: list[SavedAttribute] = [] + + if ".sizes()" in formula or "->sizes()" in formula: + raise RuntimeError( + ".sizes() is not supported in derivative formulas. Instead, please use the SymInt version," + + f".sym_sizes(), which returned a c10::SymIntArrayRef. formula={formula}" + ) + if re.search(r"\.size\([-]?\d+\)", formula) or re.search( + r"->size\([-]?\d+\)", formula + ): + raise RuntimeError( + ".size(int) is not supported in derivative formulas. Instead, please use the SymInt version," + + f".sym_size(int), which returned a c10::SymIntArrayRef. formula={formula}" + ) + if ".strides()" in formula or "->strides()" in formula: + raise RuntimeError( + ".strides() is not supported in derivative formulas. Instead, please use the SymInt version," + + f".sym_strides(), which returned a c10::SymIntArrayRef. formula={formula}" + ) + for nctype in nctypes: + # pyrefly: ignore [bad-assignment] + name = ( + nctype.name.name if isinstance(nctype.name, SpecialArgName) else nctype.name + ) + # First search the formula for expressions which can be evaluated + # when the autograd Function is created to avoid saving variables + for regex, info in REPLACEMENTS: + + def repl(m: re.Match[str]) -> str: + suffix: str = ( + # pyrefly: ignore [bad-assignment] + info["suffix"](m) if callable(info["suffix"]) else info["suffix"] + ) + expr: str = info["expr"](name) if "expr" in info else m.group(0) + saved.append( + SavedAttribute( + nctype=info["nctype"](name + suffix), + expr=expr, + ) + ) + if "res" in info: + replacement: str = info["res"](name) + return replacement + return name + suffix + + formula = re.sub(regex.format(name), repl, formula) + + # std::optional types stored in Backward nodes must be + # converted to std::optional before being passed into + # the backward function + if nctype.type == OptionalCType(BaseCType(stringT)): + formula = re.sub( + rf"\b{name}\b", + f"{name}.has_value() ? std::optional({name}.value()) : std::nullopt", + formula, + ) + + # Find any variables which remain in the formula and save them + if re.search(IDENT_REGEX.format(name), formula): + saved.append( + SavedAttribute( + nctype=nctype, + expr=name, + ) + ) + + return formula, tuple(saved) + + +def _create_op_prefix(name: str) -> str: + r"""Takes a native function name converts to an op prefix name. + + Note that the "name" parameter must be the native function name + without the optional variant suffix, so "add" instead of + "add.out". + + OP names correspond to classes, hence the change to title case. + + Example:: + + >>> _create_op_prefix("add") + 'AddBackward' + """ + camel_case = "".join([p.title() for p in name.split("_")]) + return (camel_case + "Backward").replace("ForwardBackward", "Backward") + + +def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]: + seen: set[str] = set() + saved: list[SavedAttribute] = [] + for var in vars: + name = ( + var.nctype.name.name + if isinstance(var.nctype.name, SpecialArgName) + else var.nctype.name + ) + if name in seen: + continue + seen.add(name) + saved.append(var) + return saved diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/selective_build/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/selective_build/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50b69d3c8e6ff0786051e3f4960b136163b4055f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/selective_build/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/selective_build/__pycache__/operator.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/selective_build/__pycache__/operator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20e0ada5e89bdb047f52cdb748c6a8f26b628244 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/selective_build/__pycache__/operator.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/selective_build/__pycache__/selector.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/selective_build/__pycache__/selector.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e374492e46ca19e5bae9cf858dfbca73ad2a4d5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/selective_build/__pycache__/selector.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/static_runtime/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/static_runtime/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/static_runtime/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/static_runtime/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc41cfd70718c5e36e00690272f04fe7dea5a2a1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/static_runtime/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/static_runtime/__pycache__/config.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/static_runtime/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..552c52c305a0456d6015e111b058ca639b824820 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/static_runtime/__pycache__/config.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/static_runtime/__pycache__/gen_static_runtime_ops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/static_runtime/__pycache__/gen_static_runtime_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e219176002599343f0f1c17e7e0b5b4a0c65594 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/static_runtime/__pycache__/gen_static_runtime_ops.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/static_runtime/__pycache__/generator.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/static_runtime/__pycache__/generator.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b738dc3d50360c53f856757e1fbcde77a8d36407 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/static_runtime/__pycache__/generator.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/static_runtime/config.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/static_runtime/config.py new file mode 100644 index 0000000000000000000000000000000000000000..9fe129f9754dd83a136fbf9dc4478e04a2242efa --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/static_runtime/config.py @@ -0,0 +1,388 @@ +from __future__ import annotations + +from torchgen.model import NativeFunctionsGroup, NativeFunctionsViewGroup + + +def func_name_base_str(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> str: + if isinstance(g, NativeFunctionsGroup): + return str(g.functional.func.name.name.base) + else: + return str(g.view.root_name) + + +is_hand_written_ops_ = frozenset( + ( + "abs", + "add", + "addmm", + "all", + "any", + "argmin", + "bmm", + "clamp", + "clamp_min", + "cumsum", + "div", + "fmod", + "index_select", + "leaky_relu", + "linear", + "log", + "matmul", + "mul", + "narrow_copy", + "nonzero", + "pow", + "remainder", + "sigmoid", + "sign", + "sub", + "tanh", + "detach", + "expand_as", + "flatten", + "narrow", + "reshape_as", + "select", + "slice", + "softmax", + "split", + "squeeze", + "transpose", + "view", + "where", + ) +) + + +def is_hand_written(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool: + name_base = func_name_base_str(g) + return name_base in is_hand_written_ops_ + + +def override_test_values(arg_map: dict[str, str], op_name: str, index: int) -> None: + assert index == 0 or index == 1 + if op_name == "addr": + if index == 0: + arg_map["self"] = "at::rand({6, 6})" + arg_map["vec1"] = "at::rand({6})" + arg_map["vec2"] = "at::rand({6})" + else: + arg_map["self"] = "at::rand({22, 22})" + arg_map["vec1"] = "at::rand({22})" + arg_map["vec2"] = "at::rand({22})" + return + if op_name == "mv": + if index == 0: + arg_map["self"] = "at::rand({6, 6})" + arg_map["vec"] = "at::rand({6})" + else: + arg_map["self"] = "at::rand({22, 22})" + arg_map["vec"] = "at::rand({22})" + return + if op_name == "addbmm": + if index == 0: + arg_map["self"] = "at::rand({6, 6})" + else: + arg_map["self"] = "at::rand({22, 22})" + return + if op_name == "cross": + if index == 0: + arg_map["self"] = "at::rand({3, 3, 3})" + arg_map["other"] = "at::rand({3, 3, 3})" + else: + arg_map["self"] = "at::rand({22, 3, 22})" + arg_map["other"] = "at::rand({22, 3, 22})" + return + if op_name == "take": + if index == 0: + arg_map["index"] = "at::randint(0, 216, {20}, torch::kInt64)" + else: + arg_map["index"] = "at::randint(0, 1000, {100}, torch::kInt64)" + return + if op_name == "take_along_dim": + if index == 0: + arg_map["indices"] = "at::argsort(self0, 1, true)" + else: + arg_map["indices"] = "at::argsort(self1, 1, true)" + return + if op_name == "masked_select": + if index == 0: + arg_map["mask"] = "at::randn({6, 6, 6}) > 0.5" + else: + arg_map["mask"] = "at::rand({22, 22, 22}) > 0.5" + return + if op_name == "orgqr": + if index == 0: + arg_map["input2"] = "at::rand({6, 6})" + else: + arg_map["input2"] = "at::rand({22, 22})" + return + if op_name == "ormqr": + if index == 0: + arg_map["input2"] = "at::rand({6, 6})" + else: + arg_map["input2"] = "at::rand({22, 22})" + return + if op_name == "quantile": + if index == 0: + arg_map["q"] = "at::rand({6})" + arg_map["interpolation"] = '"linear"' + else: + arg_map["q"] = "at::rand({22})" + arg_map["interpolation"] = '"linear"' + return + if op_name == "nanquantile": + if index == 0: + arg_map["q"] = "at::rand({6})" + arg_map["interpolation"] = '"linear"' + else: + arg_map["q"] = "at::rand({22})" + arg_map["interpolation"] = '"linear"' + return + if op_name == "multi_margin_loss": + if index == 0: + arg_map["self"] = "at::rand({6, 6})" + arg_map["target"] = "at::randint(6, {6}, torch::kInt64)" + arg_map["weight"] = "at::rand({6})" + else: + arg_map["self"] = "at::rand({22, 22})" + arg_map["target"] = "at::randint(22, {22}, torch::kInt64)" + arg_map["weight"] = "at::rand({22})" + return + if op_name == "multilabel_margin_loss": + if index == 0: + arg_map["self"] = "at::rand({6, 6})" + arg_map["target"] = "at::randint(6, {6, 6}, torch::kInt64)" + else: + arg_map["self"] = "at::rand({22, 22})" + arg_map["target"] = "at::randint(22, {22, 22}, torch::kInt64)" + return + if op_name == "nll_loss": + if index == 0: + arg_map["self"] = "at::rand({6, 6})" + arg_map["target"] = "at::randint(6, {6}, torch::kInt64)" + arg_map["weight"] = "at::rand({6})" + else: + arg_map["self"] = "at::rand({22, 22})" + arg_map["target"] = "at::randint(22, {22}, torch::kInt64)" + arg_map["weight"] = "at::rand({22})" + return + if op_name == "nll_loss2d": + if index == 0: + arg_map["self"] = "at::rand({6, 6, 6, 6})" + arg_map["target"] = "at::randint(6, {6, 6, 6}, torch::kInt64)" + arg_map["weight"] = "at::rand({6})" + else: + arg_map["self"] = "at::rand({22, 22, 22, 22})" + arg_map["target"] = "at::randint(22, {22, 22, 22}, torch::kInt64)" + arg_map["weight"] = "at::rand({22})" + return + if op_name in ( + "fft_fft", + "fft_ifft", + "fft_rfft", + "fft_irfft", + "fft_hfft", + "fft_ihfft", + ): + arg_map["norm"] = '"forward"' + return + if op_name == "linalg_tensorinv": + if index == 0: + arg_map["self"] = "at::rand({6, 6, 6, 6})" + arg_map["ind"] = "2" + else: + arg_map["self"] = "at::rand({22, 22, 22, 22})" + arg_map["ind"] = "2" + return + if op_name == "addmv": + if index == 0: + arg_map["self"] = "at::rand({2})" + arg_map["mat"] = "at::rand({2, 2})" + arg_map["vec"] = "at::rand({2})" + else: + arg_map["self"] = "at::rand({35})" + arg_map["mat"] = "at::rand({35, 35})" + arg_map["vec"] = "at::rand({35})" + return + if op_name == "acosh": + if index == 0: + arg_map["self"] = "at::rand({2, 2, 2}) + at::ones({2, 2, 2})" + else: + arg_map["self"] = "at::rand({5, 5, 5}) + at::ones({5, 5, 5})" + return + if op_name == "adaptive_max_pool2d_backward": + if index == 0: + arg_map["grad_output"] = "at::rand({2, 2, 2}, at::kFloat)" + arg_map["self"] = "at::rand({2, 2, 2}, at::kFloat)" + arg_map["indices"] = "at::randint(0, 1, {2, 2, 2}, at::kLong)" + else: + arg_map["grad_output"] = "at::rand({3, 3, 3}, at::kFloat)" + arg_map["self"] = "at::rand({3, 3, 3}, at::kFloat)" + arg_map["indices"] = "at::randint(0, 1, {3, 3, 3}, at::kLong)" + return + if op_name == "adaptive_max_pool3d_backward": + if index == 0: + arg_map["grad_output"] = "at::rand({2, 2, 2, 2}, at::kFloat)" + arg_map["self"] = "at::rand({2, 2, 2, 2}, at::kFloat)" + arg_map["indices"] = "at::randint(0, 1, {2, 2, 2, 2}, at::kLong)" + else: + arg_map["grad_output"] = "at::rand({3, 3, 3, 3}, at::kFloat)" + arg_map["self"] = "at::rand({3, 3, 3, 3}, at::kFloat)" + arg_map["indices"] = "at::randint(0, 1, {3, 3, 3, 3}, at::kLong)" + return + if op_name == "bitwise_left_shift": + if index == 0: + arg_map["self"] = "at::randint(1, 1 << 4, {6, 6, 6}, at::kInt)" + arg_map["other"] = "at::randint(1, 26, {6, 6, 6}, at::kInt)" + else: + arg_map["self"] = "at::randint(1, 1 << 4, {22, 22, 22}, at::kInt)" + arg_map["other"] = "at::randint(1, 26, {22, 22, 22}, at::kInt)" + return + if op_name == "bitwise_right_shift": + if index == 0: + arg_map["self"] = "at::randint(1 << 21, 1 << 30, {6, 6, 6}, at::kInt)" + arg_map["other"] = "at::randint(1, 22, {6, 6, 6}, at::kInt)" + else: + arg_map["self"] = "at::randint(1 << 21, 1 << 30, {22, 22, 22}, at::kInt)" + arg_map["other"] = "at::randint(1, 22, {22, 22, 22}, at::kInt)" + return + if op_name == "gather": + if index == 0: + arg_map["self"] = "at::randint(1, 100, {2,2,2}, at::kInt)" + arg_map["dim"] = "1" + arg_map["index"] = "at::randint(0, 1, {2,2,2}, torch::kInt64)" + arg_map["sparse_grad"] = "false" + else: + arg_map["self"] = "at::randint(1, 100, {5,5,5}, at::kInt)" + arg_map["dim"] = "1" + arg_map["index"] = "at::randint(0, 4, {5,5,5}, torch::kInt64)" + arg_map["sparse_grad"] = "false" + return + if op_name == "gelu": + if index == 0: + arg_map["self"] = "at::rand({6, 6, 6})" + arg_map["approximate"] = '"tanh"' + else: + arg_map["self"] = "at::rand({22, 22, 22})" + arg_map["approximate"] = '"tanh"' + return + if op_name == "gelu_backward": + if index == 0: + arg_map["grad_output"] = "at::rand({6, 6, 6})" + arg_map["self"] = "at::rand({6, 6, 6})" + arg_map["approximate"] = '"tanh"' + else: + arg_map["grad_output"] = "at::rand({22, 22, 22})" + arg_map["self"] = "at::rand({22, 22, 22})" + arg_map["approximate"] = '"tanh"' + return + if op_name == "index_add": + if index == 0: + arg_map["self"] = "at::rand({2})" + arg_map["dim"] = "0" + arg_map["index"] = "at::randint(0, 1, {2}, at::kInt)" + arg_map["source"] = "at::rand({2})" + arg_map["alpha"] = "2" + else: + arg_map["self"] = "at::rand({16})" + arg_map["dim"] = "0" + arg_map["index"] = "at::randint(0, 10, {16}, at::kInt)" + arg_map["source"] = "at::rand({16})" + arg_map["alpha"] = "2" + return + if op_name == "index_copy": + if index == 0: + arg_map["self"] = "at::rand({2})" + arg_map["dim"] = "0" + arg_map["index"] = "at::randint(0, 1, {2}, at::kLong)" + arg_map["source"] = "at::rand({2})" + else: + arg_map["self"] = "at::rand({32})" + arg_map["dim"] = "0" + arg_map["index"] = "at::randint(0, 10, {32}, at::kLong)" + arg_map["source"] = "at::rand({32})" + return + if op_name == "linalg_cross": + if index == 0: + arg_map["self"] = "at::rand({6, 3, 6})" + arg_map["other"] = "at::rand({6, 3, 6})" + arg_map["dim"] = "1" + else: + arg_map["self"] = "at::rand({22, 3, 22})" + arg_map["other"] = "at::rand({22, 3, 22})" + arg_map["dim"] = "1" + return + if op_name == "nll_loss_backward": + if index == 0: + arg_map["grad_output"] = "at::rand({})" + arg_map["self"] = "at::rand({6})" + arg_map["target"] = "at::randint(0, 5, {6}, torch::kInt64)" + arg_map["weight"] = "at::rand({6})" + arg_map["reduction"] = "1" + arg_map["ignore_index"] = "1" + arg_map["total_weight"] = "at::rand({})" + else: + arg_map["grad_output"] = "at::rand({})" + arg_map["self"] = "at::rand({36})" + arg_map["target"] = "at::randint(0, 11, {36}, torch::kInt64)" + arg_map["weight"] = "at::rand({36})" + arg_map["reduction"] = "1" + arg_map["ignore_index"] = "1" + arg_map["total_weight"] = "at::rand({})" + return + if op_name in ["scatter", "scatter_add", "_scatter_reduce"]: + if index == 0: + arg_map["self"] = "at::randint(1, 100, {2,2,2}, torch::kInt64)" + arg_map["index"] = "at::randint(0, 1, {2,2,2}, torch::kInt64)" + arg_map["src"] = "at::randint(1, 100, {2,2,2}, torch::kInt64)" + else: + arg_map["self"] = "at::randint(1, 100, {5,5,5}, torch::kInt64)" + arg_map["index"] = "at::randint(0, 1, {5,5,5}, torch::kInt64)" + arg_map["src"] = "at::randint(1, 100, {5,5,5}, torch::kInt64)" + if "reduce" in arg_map: + arg_map["reduce"] = '"sum"' if op_name == "_scatter_reduce" else '"add"' + return + if op_name == "scatter_reduce": + arg_map["reduce"] = '"mean"' + if index == 0: + arg_map["index"] = "at::randint(6, {6, 6, 6}, torch::kInt64)" + else: + arg_map["index"] = "at::randint(22, {22, 22, 22}, torch::kInt64)" + return + if op_name == "special_zeta": + if index == 0: + arg_map["self"] = "at::rand({2,2,2}, at::kDouble) + at::ones({2,2,2})" + arg_map["other"] = "at::rand({2,2,2}, at::kDouble) + at::ones({2,2,2})" + else: + arg_map["self"] = "at::rand({5,5,5}, at::kDouble) + at::ones({5,5,5})" + arg_map["other"] = "at::rand({5,5,5}, at::kDouble) + at::ones({5,5,5})" + return + if op_name == "_convert_indices_from_csr_to_coo": + if index == 0: + arg_map["crow_indices"] = "torch::tensor({1}, torch::kInt32)" + arg_map["col_indices"] = "torch::tensor({0, 1, 0}, torch::kInt32)" + arg_map["out_int32"] = "false" + else: + arg_map["crow_indices"] = "torch::tensor({0}, torch::kInt32)" + arg_map["col_indices"] = ( + "torch::tensor({0, 1, 0, 2, 1, 2, 0, 1, 0, 2, 1, 2}, torch::kInt32)" + ) + arg_map["out_int32"] = "false" + return + if op_name == "_convert_indices_from_coo_to_csr": + if index == 0: + arg_map["self"] = "at::randint(0, 3, {2}, at::kInt)" + arg_map["size"] = "10" + arg_map["out_int32"] = "false" + else: + arg_map["self"] = "at::randint(0, 3, {12}, at::kInt)" + arg_map["size"] = "24" + arg_map["out_int32"] = "false" + return + if op_name in ("diagonal", "linalg_diagonal"): + arg_map["offset"] = "0" + arg_map["dim1"] = "2" + arg_map["dim2"] = "1" + return diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/static_runtime/gen_static_runtime_ops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/static_runtime/gen_static_runtime_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d6909bc4d7f67fc13fb9f61e00f4709a4ff5ad4e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/static_runtime/gen_static_runtime_ops.py @@ -0,0 +1,231 @@ +from __future__ import annotations + +import argparse +import itertools +import os +from typing import TYPE_CHECKING, TypeVar + +from libfb.py.log import set_simple_logging # type: ignore[import] + +from torchgen import gen +from torchgen.context import native_function_manager +from torchgen.model import DispatchKey, NativeFunctionsGroup, NativeFunctionsViewGroup +from torchgen.static_runtime import config, generator + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +# Given a list of `grouped_native_functions` sorted by their op names, return a list of +# lists each of which groups ops that share the base name. For example, `mean` and +# `mean.dim` are grouped together by this function. + +NativeGroupT = TypeVar( + "NativeGroupT", + bound=NativeFunctionsGroup | NativeFunctionsViewGroup, +) + + +def group_functions_by_op_name( + grouped_native_functions: Sequence[NativeGroupT], +) -> Sequence[Sequence[NativeGroupT]]: + if not grouped_native_functions: + return [] + groups = [] + + def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool: + with native_function_manager(g): + return generator.is_supported(g) + + eligible_ops = (g for g in grouped_native_functions if is_supported(g)) + groups = [ + list(group) + for k, group in ( + itertools.groupby( + eligible_ops, + key=config.func_name_base_str, + ) + ) + ] + + return groups + + +def clang_format(cpp_file_path: str) -> None: + import subprocess + + subprocess.check_call(["clang-format", "-i", cpp_file_path]) + + +def write_cpp(cpp_ops: Sequence[str], file_path: str) -> None: + code = "\n".join(cpp_ops) + generated = f"""// @lint-ignore-every CLANGTIDY HOWTOEVEN +// AUTO-GENERATED FROM: torchgen/static_runtime/gen_static_runtime_ops.py +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch {{ +namespace jit {{ + +{code} + +}} // namespace jit +}} // namespace torch +""" + with open(file_path, "w") as f: + f.write(generated) + clang_format(file_path) + + +def write_test_cpp(cpp_ops: Sequence[str], file_path: str) -> None: + code = "\n".join(cpp_ops) + generated = f"""// @lint-ignore-every CLANGTIDY HOWTOEVEN +// AUTO-GENERATED FROM: torchgen/static_runtime/gen_static_runtime_ops.py +#include +#include +#include + +#include "test_utils.h" + +using namespace caffe2; +using namespace torch; +using namespace torch::jit; +using namespace torch::jit::test; +using c10::IValue; + +{code} + +""" + with open(file_path, "w") as f: + f.write(generated) + clang_format(file_path) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate ATen source files") + parser.add_argument( + "-s", + "--source-path", + help="path to source directory for ATen", + default="caffe2/aten/src/ATen", + ) + parser.add_argument( + "-p", + "--generated-ops-cpp-path", + help="path to directory to generate op dispatcher .cpp file", + default="caffe2/torch/csrc/jit/runtime/static/generated_ops.cpp", + ) + parser.add_argument( + "-t", + "--generated-ops-test-cpp-path", + help="path to directory to generate op dispatcher .cpp file", + default="caffe2/benchmarks/static_runtime/test_generated_ops.cc", + ) + options = parser.parse_args() + native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml") + tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml") + parsed_yaml = gen.parse_native_yaml(native_yaml_path, tags_yaml_path) + native_functions, backend_indices = ( + parsed_yaml.native_functions, + parsed_yaml.backend_indices, + ) + + op_generator = generator.GenOpDispatcher() + test_case_generator = generator.GenOpTestCase() + + native_functions_groups = [ + g + for g in gen.get_grouped_native_functions(native_functions) + if isinstance(g, NativeFunctionsGroup) + ] + + supported_functions_groups = group_functions_by_op_name(native_functions_groups) + + out_variant_op_result = [ + op_generator.out_variant(groups, backend_indices[DispatchKey.CPU]) + for groups in supported_functions_groups + ] + out_variant_test_result = [ + test_case_generator.out_variant(groups) for groups in supported_functions_groups + ] + + native_functions_view_groups = [ + g + for g in gen.get_grouped_by_view_native_functions(native_functions) + if isinstance(g, NativeFunctionsViewGroup) + ] + + supported_functions_view_groups = group_functions_by_op_name( + native_functions_view_groups + ) + + view_op_result = [ + op_generator.view(groups, backend_indices[DispatchKey.CPU]) + for groups in supported_functions_view_groups + ] + view_test_result = [ + test_case_generator.view(groups) for groups in supported_functions_view_groups + ] + + op_result = out_variant_op_result + ["\n\n"] + view_op_result + test_result = out_variant_test_result + ["\n\n"] + view_test_result + + write_cpp(op_result, options.generated_ops_cpp_path) + write_test_cpp(test_result, options.generated_ops_test_cpp_path) + + print( + f"\ntotal grouped native ops: {len(gen.get_grouped_native_functions(native_functions)):d}" + ) + + print(f"grouped native ops with out variant: {len(native_functions_groups):d}") + supported_functions_num = sum(len(groups) for groups in supported_functions_groups) + print(f"generated functions groups with out variant: {supported_functions_num:d}") + + print(f"\nview grouped native ops: {len(native_functions_view_groups):d}") + supported_view_functions_num = sum( + len(groups) for groups in supported_functions_view_groups + ) + print(f"generated functions view groups: {supported_view_functions_num:d}") + + print( + f"\noverall generated : {supported_functions_num + supported_view_functions_num:d}" + ) + + +if __name__ == "__main__": + set_simple_logging(escape_newlines=False) + main() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/static_runtime/generator.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/static_runtime/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..8ad2fd3c458892568429f86e5cd53c26982b38fd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torchgen/static_runtime/generator.py @@ -0,0 +1,814 @@ +from __future__ import annotations + +import json +import logging +import math +from typing import TYPE_CHECKING + +import torchgen.api.cpp as cpp +from torchgen.context import native_function_manager +from torchgen.model import ( + Argument, + BackendIndex, + BaseTy, + BaseType, + FunctionSchema, + NativeFunctionsGroup, + NativeFunctionsViewGroup, + OptionalType, + SelfArgument, + TensorOptionsArguments, + Type, +) +from torchgen.static_runtime import config + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +logger: logging.Logger = logging.getLogger() + + +def has_alias( + arguments: Sequence[Argument | SelfArgument | TensorOptionsArguments], +) -> bool: + for arg in arguments: + annotation = getattr(arg, "annotation", None) + if not annotation: + continue + alias_set = getattr(annotation, "alias_set", ()) + if alias_set: + return True + return False + + +BLOCKED_OPS = frozenset( + ( + # non cpu ops + "sparse_sampled_addmm", + "hspmm", + "linalg_svdvals", + # sparse ops + "sspaddmm", + "coalesce", + "_indices", + "indices", + "_values", + "values", + "crow_indices", + "col_indices", + # deprecated ops + "floor_divide", + "ger", + # buggy ops + "conj_physical", # P495807361 + "binary_cross_entropy", # P496394764 + "arccosh", + # uncommon ops + "cholesky", + "lu_solve", + "linalg_cholesky", + "linalg_householder_product", + "linalg_ldl_solve", + "_compute_linear_combination", + # training related ops + "_make_dual", + # cannot call directly + "_fw_primal", + # no documentation + "_index_reduce", + # TODO: these ones got added recently and need manual inspection + "_new_zeros_with_same_feature_meta", + "_conj_physical", + "binary_cross_entropy_with_logits", + "bincount", + "conv_tbc", + "copy", + "_copy_from", + "_copy_from_and_resize", + "count_nonzero", + "cudnn_affine_grid_generator", + "cudnn_affine_grid_generator_backward", + "cudnn_grid_sampler", + "diag_embed", + "embedding", + "embedding_dense_backward", + "_embedding_bag_dense_backward", + "_embedding_bag_per_sample_weights_backward", + "grid_sampler_2d", + "_grid_sampler_2d_cpu_fallback", + "grid_sampler_3d", + "isnan", + "mkldnn_linear", + "median", + "nanmedian", + "_sparse_sparse_matmul", + "batch_norm_backward_elemt", + "_euclidean_dist", + "pixel_shuffle", + "pixel_unshuffle", + "channel_shuffle", + "_reshape_nested_backward", + "relu", + "prelu", + "celu", + "slice_scatter", + "select_scatter", + "diagonal_scatter", + "sum", + "_mkldnn_transpose", + "_nested_tensor_from_mask", + "_nested_from_padded", + "_nested_tensor_size", + "_nested_from_padded_and_nested_example", + "_standard_gamma_grad", + "_dirichlet_grad", + "native_norm", + "_sparse_softmax", + "_sparse_softmax_backward_data", + "_sparse_log_softmax", + "_sparse_log_softmax_backward_data", + "zero", + "_sparse_addmm", + "sparse_mask", + "_sparse_mask_projection", + "_to_dense", + "_coalesce", + "_coalesced", + "copy_sparse_to_sparse", + "to_sparse", + "to_sparse_csr", + "to_sparse_csc", + "to_mkldnn", + "quantize_per_tensor_dynamic", + "quantize_per_channel", + "q_per_channel_scales", + "q_per_channel_zero_points", + "int_repr", + "_make_per_channel_quantized_tensor", + "set", + "lift", + "lift_fresh", + "lift_fresh_copy", + "masked_scatter", + "_masked_softmax", + "_masked_softmax_backward", + "put", + "index_reduce", + "trace", + "_cholesky_solve_helper", + "dist", + "max", + "_torch_cuda_cu_linker_symbol_op", + "glu_jvp", + "glu_backward_jvp", + "hardswish_backward", + "rrelu_with_noise_backward", + "mkldnn_adaptive_avg_pool2d_backward", + "_adaptive_avg_pool2d_backward", + "_adaptive_avg_pool3d_backward", + "isinf", + "linalg_lu_solve", + "linalg_vecdot", + "linalg_matrix_exp", + "linalg_eigvalsh", + "_test_warn_in_autograd", + "_test_autograd_multiple_dispatch_view", + "_test_autograd_multiple_dispatch_view_copy", + "_segment_reduce", + "_segment_reduce_backward", + "_fw_primal_copy", + "_make_dual_copy", + "view_as_real_copy", + "view_as_complex_copy", + "_conj_copy", + "_neg_view_copy", + "diagonal_copy", + "detach_copy", + "squeeze_copy", + "t_copy", + "unsqueeze_copy", + "_indices_copy", + "_values_copy", + "indices_copy", + "values_copy", + "crow_indices_copy", + "col_indices_copy", + "ccol_indices", + "ccol_indices_copy", + "row_indices", + "row_indices_copy", + "unfold_copy", + "alias_copy", + "_triton_multi_head_attention", + "special_airy_ai", + "special_bessel_j0", + "special_bessel_j1", + "special_bessel_y0", + "special_bessel_y1", + "special_chebyshev_polynomial_t", + "special_chebyshev_polynomial_u", + "special_chebyshev_polynomial_v", + "special_chebyshev_polynomial_w", + "special_hermite_polynomial_h", + "special_hermite_polynomial_he", + "special_laguerre_polynomial_l", + "special_legendre_polynomial_p", + "special_modified_bessel_i0", + "special_modified_bessel_i1", + "special_modified_bessel_k0", + "special_modified_bessel_k1", + "special_scaled_modified_bessel_k0", + "special_scaled_modified_bessel_k1", + "special_shifted_chebyshev_polynomial_t", + "special_shifted_chebyshev_polynomial_u", + "special_shifted_chebyshev_polynomial_v", + "special_shifted_chebyshev_polynomial_w", + "special_spherical_bessel_j0", + "_foobar", + "_nested_tensor_strides", + "_nested_tensor_storage_offsets", + "_nested_get_values", # no CPU backend + "_nested_get_values_copy", # no CPU backend + "_nested_view_from_jagged", # testing needs to be patched + "_nested_view_from_jagged_copy", # testing needs to be patched + "_nested_view_from_buffer", # testing needs to be patched + "_nested_view_from_buffer_copy", # testing needs to be patched + "_int_mm", # testing needs to be patched + "_to_sparse_csc", # testing needs to be patched + "_to_sparse_csr", # testing needs to be patched + "segment_reduce", # testing needs to be patched + ) +) + + +def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool: + base_op_name = "" + func = None + if isinstance(g, NativeFunctionsViewGroup): + base_op_name = g.view.root_name + func = g.view.func + else: + base_op_name = g.out.func.name.name.base + func = g.out.func + if config.is_hand_written(g): + logger.info("HAND WRITTEN: %s", base_op_name) + return False + if base_op_name in BLOCKED_OPS: + logger.info("BLOCKED: %s", base_op_name) + return False + for arg in func.schema_order_arguments(): + maybe_method = ivalue_type_conversion_method(arg.type) + if not maybe_method: + # Type converting is unsupported yet. + logger.info("NOT SUPPORTED TYPE CONVERTING: %s", func) + return False + + if isinstance(g, NativeFunctionsViewGroup): + # TODO: stop doing type tests by converting to C++ and then testing + # the string, just test the dang thing directly + if "at::Tensor" != cpp.returns_type(func.returns, symint=False).cpp_type(): + # Returns a non-Tensor value. + logger.info("NON-TENSOR RET TYPE: %s", str(func)) + return False + return True + + # For out variant ops, we need to check the arguments of its functional func. + for arg in g.functional.func.schema_order_arguments(): + maybe_method = ivalue_type_conversion_method(arg.type) + if not maybe_method: + # Type converting is unsupported yet. + logger.info("NOT SUPPORTED TYPE CONVERTING: %s", g.functional.func) + return False + + if not g.structured: + # In case of unstructured op, we check if it has out variant implementation. + # The out variant implementation satisfies the minimum requirement that it has the output tensor as the last + # parameter. + if ( + not hasattr(g, "out") + or not str(func).endswith("Tensor(a!) out) -> Tensor(a!)") + or not str(func.name).endswith(".out") + ): + return False + # TODO: stop type testing by converting to C++ + if "at::Tensor &" != cpp.returns_type(func.returns, symint=False).cpp_type(): + logger.info("NON_TENSOR RET TYPE: %s", func) + return False + if has_alias(func.arguments.non_out): + # This op may create an alias of inputs. + logger.info("INPUTS ALIAS: %s", base_op_name) + return False + return True + + +def ivalue_type_conversion_method( + arg_type: BaseType | OptionalType | Type, +) -> tuple[bool, str] | None: + """ + Return the method call expression of `c10::ivalue' to convert its contained value to + the expected value of `arg_type` type. For example, for `arg_type` == BaseTy.Tensor, + this function returns ".toTensor()", so that it can be appended to the ivalue's + variable name to get the value of the expected type. + """ + type_conversion_methods = { + BaseTy.Tensor: ((True, "toTensor()"), (False, "toOptional()")), + BaseTy.int: ((False, "toInt()"), (False, "toOptional()")), + BaseTy.bool: ((False, "toBool()"), (False, "toOptional()")), + BaseTy.Scalar: ((False, "toScalar()"), (False, "toOptional()")), + BaseTy.ScalarType: ( + (False, "toScalarType()"), + (False, "toOptional()"), + ), + BaseTy.str: ( + (False, "toStringView()"), + (False, "toOptional()"), + (False, "toOptional<::std::string_view>()"), + ), + } + + base_ty_object = None + if isinstance(arg_type, BaseType): + base_ty_object = arg_type.name + elif isinstance(arg_type, OptionalType): + if not isinstance(arg_type.elem, BaseType): + # ListType is currently unsupported. + return None + base_ty_object = arg_type.elem.name + else: + return None + + if base_ty_object not in type_conversion_methods: + return None + methods = type_conversion_methods[base_ty_object] + if isinstance(arg_type, BaseType): + return methods[0] + return methods[1] + + +should_use_int_tensor_ops_ = frozenset( + ( + "bitwise_not", + "bitwise_and", + "bitwise_or", + "bitwise_xor", + "bitwise_left_shift", + "bitwise_right_shift", + "gcd", + "lcm", + "scatter", + "gather", + "_convert_indices_from_coo_to_csr", + "_convert_indices_from_csr_to_coo", + ) +) +should_use_complex_tensor_ops_ = frozenset(("view_as_real", "imag", "_conj")) + + +def should_use_int_tensor(op_name: str) -> bool: + return op_name in should_use_int_tensor_ops_ + + +def should_use_complex_tensor(op_name: str) -> bool: + return op_name in should_use_complex_tensor_ops_ + + +test_tensor_dim_ops_1_ = frozenset( + ( + "addmv", + "index_add", + "_convert_indices_from_coo_to_csr", + "_convert_indices_from_csr_to_coo", + "nll_loss_backward", + "dot", + "vdot", + "outer", + "ger", + ) +) +test_tensor_dim_ops_2_ = frozenset( + ("addmm", "mm", "nuclear_norm", "diag", "_addmm_activation", "matrix_H", "t") +) + + +def test_tensor_dim(op_name: str) -> int: + if op_name in test_tensor_dim_ops_1_: + return 1 + if op_name in test_tensor_dim_ops_2_: + return 2 + return 3 + + +test_tensor_shapes_string = '{"view_as_complex": "{2, 2}"}' +test_tensor_shape_json: dict[str, str] = json.loads(test_tensor_shapes_string) + + +def test_tensor_shape(op_name: str) -> str: + if op_name in test_tensor_shape_json: + return test_tensor_shape_json[op_name] + else: + return "" + + +def test_value_expression( + arg_type: BaseType | OptionalType | Type, index: int, op_name: str +) -> str: + tensor_size_ex = test_tensor_shape(op_name) + if tensor_size_ex == "": + num_tensors = 16 if index == 0 else 64 + num_dim = test_tensor_dim(op_name) + size_per_dim = math.ceil(num_tensors / float(num_dim)) + size_per_dim += size_per_dim % 2 + tensor_size_ex = "{{{}}}".format(",".join([f"{size_per_dim}"] * num_dim)) + if should_use_int_tensor(op_name): + tensor_expression = f"at::randint(1, 100, {tensor_size_ex}, at::kInt)" + elif should_use_complex_tensor(op_name): + tensor_expression = f"at::randn({tensor_size_ex}, at::kComplexFloat)" + else: + tensor_expression = f"at::rand({tensor_size_ex})" + + value_expressions = { + BaseTy.Tensor: tensor_expression, + BaseTy.int: "1", + BaseTy.bool: "false", + BaseTy.Scalar: "2", + BaseTy.ScalarType: "at::ScalarType::Float", + BaseTy.str: '"floor"', + } + + base_ty_object = None + if isinstance(arg_type, BaseType): + base_ty_object = arg_type.name + else: + assert isinstance(arg_type, OptionalType) and isinstance( + arg_type.elem, BaseType + ) + base_ty_object = arg_type.elem.name + assert base_ty_object in value_expressions, "not expected type" + value_expression = value_expressions[base_ty_object] + return value_expression + + +def generate_test_value_definitions(schema: FunctionSchema, index: int) -> str: + assert not schema.is_out_fn() + schema_name = schema.name.name.base + arg_map = {} + for arg in schema.schema_order_arguments(): + test_value_exp = test_value_expression(arg.type, index, schema_name) + arg_map[arg.name] = test_value_exp + config.override_test_values(arg_map, schema_name, index) + arg_populations = [] + for arg_name, arg_value in arg_map.items(): + arg_populations.append(f"auto {arg_name}{index} = {arg_value}") + return ";\n ".join(arg_populations) + ";" + + +def generate_test_value_names(schema: FunctionSchema, index: int) -> str: + assert not schema.is_out_fn() + return ",".join(f"{arg.name}{index}" for arg in schema.schema_order_arguments()) + + +generate_test_ir_arguments_base_ty_to_type_str_ = { + BaseTy.Tensor: "Tensor", + BaseTy.int: "int", + BaseTy.float: "float", + BaseTy.str: "str", + BaseTy.Scalar: "int", + BaseTy.ScalarType: "int", + BaseTy.bool: "bool", +} + + +def generate_test_ir_arguments( + schema: FunctionSchema, +) -> list[tuple[str, str | None]]: + def ir_argument(arg: Argument) -> tuple[str, str | None]: + t = arg.type + add_optional = False + if isinstance(t, OptionalType): + t = t.elem + add_optional = True + assert isinstance(t, BaseType) + type_str = None + if t.name in generate_test_ir_arguments_base_ty_to_type_str_: + type_str = generate_test_ir_arguments_base_ty_to_type_str_[t.name] + if type_str and add_optional: + type_str = f"{type_str}?" + return ("%" + arg.name, type_str) + + return [ir_argument(arg) for arg in schema.schema_order_arguments()] + + +def generate_arg_extraction(schema: FunctionSchema) -> str: + arg_populations = [] + for i, arg in enumerate(schema.schema_order_arguments()): + maybe_method = ivalue_type_conversion_method(arg.type) + assert maybe_method + is_reference, type_conversion_method = maybe_method + reference = "&" if is_reference else "" + arg_populations.append( + f"const auto{reference} {arg.name} = p_node->Input({i}).{type_conversion_method}" + ) + return ";\n ".join(arg_populations) + ";" + + +def get_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str: + kernel = backend_index.get_kernel(g.functional) + if g.structured or kernel is None: + return cpp.name(g.functional.func) + return kernel.kernel + + +def get_out_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str: + kernel = backend_index.get_kernel(g.out) + if g.structured or kernel is None: + return cpp.name(g.out.func) + return kernel.kernel + + +def generate_non_out_variant_call( + g: NativeFunctionsGroup, backend_index: BackendIndex +) -> str: + schema = g.functional.func + assert not schema.is_out_fn() + kernel_name = get_kernel_name(g, backend_index) + arg_names = (arg.name for arg in schema.schema_order_arguments()) + namespace_name = "cpu" if g.structured else "native" + return f"at::{namespace_name}::{kernel_name}({','.join(arg_names)})" + + +def generate_call_to_view_ops( + g: NativeFunctionsViewGroup, backend_index: BackendIndex +) -> str: + schema = g.view.func + kernel_name = cpp.name(schema) + kernel = backend_index.get_kernel(g.view) + if kernel: + kernel_name = kernel.kernel + arg_names = (arg.name for arg in schema.schema_order_arguments()) + namespace_name = "native" + return f"at::{namespace_name}::{kernel_name}({','.join(arg_names)})" + + +def generate_out_variant_call( + g: NativeFunctionsGroup, backend_index: BackendIndex +) -> str: + schema = g.out.func + assert schema.is_out_fn() + arg_names = [] + kernel_name = get_out_kernel_name(g, backend_index) + if g.structured: + # structured op starts with the output tensor argument. + arg_names = [out_arg.name for out_arg in schema.arguments.out] + else: + arg_names = [] + for arg in schema.arguments.non_out: + if isinstance(arg, SelfArgument): + arg_names.append(arg.argument.name) + else: + assert isinstance(arg, Argument) + arg_names.append(arg.name) + if not g.structured: + assert len(schema.arguments.out) == 1 + arg_names.append(schema.arguments.out[0].name) + cpp_arg_names = ",".join(arg_names) + namespace_name = "cpu" if g.structured else "native" + return f"at::{namespace_name}::{kernel_name}({cpp_arg_names})" + + +no_memory_resize_ops = frozenset( + ( + "isin.Scalar_Tensor", + "index_add", + "dot", + "vdot", + "nuclear_norm", + "histc", + "l1_loss", + "multi_margin_loss", + "multilabel_margin_loss", + "nll_loss", + "nll_loss2d", + "prod", + ) +) + + +def should_check_resize(schema: FunctionSchema) -> bool: + schema_str = str(schema) + type_variant_op_name = schema_str[: schema_str.find("(")] + return type_variant_op_name not in no_memory_resize_ops + + +def op_name_from_group(g: NativeFunctionsGroup) -> str: + return g.functional.func.name.name.base + + +class GenOpDispatcher: + def out_variant( + self, groups: Sequence[NativeFunctionsGroup], backend_index: BackendIndex + ) -> str: + if not groups: + return "" + generated_type_variants = [] + for g in groups: + with native_function_manager(g): + assert is_supported(g) + assert isinstance(g, NativeFunctionsGroup) + generated_type_variant = self.out_variant_op_generator(g, backend_index) + generated_type_variants.append(generated_type_variant) + op_name = op_name_from_group(groups[0]) + body = "\n".join(generated_type_variants) + generated = f""" +REGISTER_OPERATOR_FUNCTOR( + aten::{op_name}, + aten_{op_name}, + [](Node* n) -> SROperator {{ + {body} + LogAndDumpSchema(n); + return nullptr; + }}) +""" + return generated + + def view( + self, groups: Sequence[NativeFunctionsViewGroup], backend_index: BackendIndex + ) -> str: + if not groups: + return "" + generated_type_variants = [] + for g in groups: + with native_function_manager(g): + assert is_supported(g) + assert isinstance(g, NativeFunctionsViewGroup) + generated_type_variant = self.view_op_generator(g, backend_index) + generated_type_variants.append(generated_type_variant) + op_name = config.func_name_base_str(groups[0]) + body = "\n".join(generated_type_variants) + generated = f""" +REGISTER_NATIVE_OPERATOR_FUNCTOR( + aten::{op_name}, + aten_{op_name}, + [](Node* n) -> SROperator {{ + {body} + LogAndDumpSchema(n); + return nullptr; + }}); +""" + return generated + + def out_variant_op_generator( + self, g: NativeFunctionsGroup, backend_index: BackendIndex + ) -> str: + functional = g.functional + schema = str(functional.func) + populated_argument = generate_arg_extraction(g.functional.func) + functional_variant_call = generate_non_out_variant_call(g, backend_index) + assert len(g.out.func.arguments.out) == 1 + out_variable_name = str(g.out.func.arguments.out[0].name) + out_variant_call = generate_out_variant_call(g, backend_index) + generated = f""" + if (n->matches(torch::schema("aten::{schema}"))) {{ + return [](ProcessedNode* p_node) {{ + {populated_argument} + if (p_node->Output(0).isNone()) {{ + p_node->Output(0) = {functional_variant_call}; + return; + }} + auto& {out_variable_name} = p_node->Output(0).toTensor(); + fastResizeToZero({out_variable_name}); + {out_variant_call}; + }}; + }}""" + return generated + + def view_op_generator( + self, g: NativeFunctionsViewGroup, backend_index: BackendIndex + ) -> str: + schema = str(g.view.func) + populated_argument = generate_arg_extraction(g.view.func) + functional_variant_call = generate_call_to_view_ops(g, backend_index) + generated = f""" + if (n->matches(torch::schema("aten::{schema}"))) {{ + return [](ProcessedNode* p_node) {{ + {populated_argument} + p_node->Output(0) = {functional_variant_call}; + }}; + }}""" + return generated + + +class GenOpTestCase: + def out_variant(self, groups: Sequence[NativeFunctionsGroup]) -> str: + if not groups: + return "" + generated_type_variants = [] + for g in groups: + with native_function_manager(g): + assert is_supported(g) + assert isinstance(g, NativeFunctionsGroup) + generated_type_variant = self.out_variant_op_test_case_generator(g) + generated_type_variants.append(generated_type_variant) + return "\n".join(generated_type_variants) + + def view(self, groups: Sequence[NativeFunctionsViewGroup]) -> str: + if not groups: + return "" + generated_type_variants = [] + for g in groups: + with native_function_manager(g): + assert is_supported(g) + assert isinstance(g, NativeFunctionsViewGroup) + generated_type_variant = self.view_op_test_case_generator(g) + generated_type_variants.append(generated_type_variant) + return "\n".join(generated_type_variants) + + def out_variant_op_test_case_generator(self, g: NativeFunctionsGroup) -> str: + schema = g.functional.func + schema_str = str(schema) + assert schema_str.find("(") > 0 + type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_") + op_name = op_name_from_group(g) + assert type_variant_op_name.startswith(op_name) + + arg_types = generate_test_ir_arguments(schema) + arg_declarations = ", ".join( + ( + arg_name if arg_type is None else f"{arg_name}: {arg_type}" + for arg_name, arg_type in arg_types + ) + ) + arg_names = ", ".join((arg_name for arg_name, _ in arg_types)) + assert ( + len(schema.returns) == 1 + and isinstance(schema.returns[0].type, BaseType) + and schema.returns[0].type.name is BaseTy.Tensor + ) + test_value_definitions = generate_test_value_definitions(schema, 0) + test_value_names = generate_test_value_names(schema, 0) + test_value_definitions2 = generate_test_value_definitions(schema, 1) + test_value_names2 = generate_test_value_names(schema, 1) + check_resize = "true" if should_check_resize(schema) else "false" + generated = f""" +TEST(StaticRuntime, autogen_{type_variant_op_name}) {{ + const std::string script = R"IR( + graph({arg_declarations}): + %bias: None = prim::Constant() + %ret = aten::{op_name}({arg_names}) + %cloned = aten::clone(%ret, %bias) + return (%cloned) + )IR"; + + {test_value_definitions} + std::vector args{{{test_value_names}}}; + testStaticRuntime(script, args, {{}}, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize}); + + {test_value_definitions2} + std::vector args2{{{test_value_names2}}}; + testStaticRuntime(script, args, args2, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize}); + +}} +""" + return generated + + def view_op_test_case_generator(self, g: NativeFunctionsViewGroup) -> str: + schema = g.view.func + schema_str = str(schema) + assert schema_str.find("(") > 0 + type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_") + op_name = g.view.root_name + assert type_variant_op_name.startswith(op_name) + + arg_types = generate_test_ir_arguments(schema) + arg_declarations = ", ".join( + ( + arg_name if arg_type is None else f"{arg_name}: {arg_type}" + for arg_name, arg_type in arg_types + ) + ) + arg_names = ", ".join((arg_name for arg_name, _ in arg_types)) + assert ( + len(schema.returns) == 1 + and isinstance(schema.returns[0].type, BaseType) + and schema.returns[0].type.name is BaseTy.Tensor + ) + test_value_definitions = generate_test_value_definitions(schema, 0) + test_value_names = generate_test_value_names(schema, 0) + generated = f""" +TEST(StaticRuntime, autogen_{type_variant_op_name}) {{ + const std::string script = R"IR( + graph({arg_declarations}): + %bias: None = prim::Constant() + %ret = aten::{op_name}({arg_names}) + %cloned = aten::clone(%ret, %bias) + return (%cloned) + )IR"; + + {test_value_definitions} + std::vector args{{{test_value_names}}}; + testStaticRuntime(script, args); +}} +""" + + return generated diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d8f64d459fc09ced442f755336be74527385324 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/__pycache__/custom_blocks.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/__pycache__/custom_blocks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9a486be3cf024fd9fd83602769caf1ff674b04b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/__pycache__/custom_blocks.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/__pycache__/diffusers_cli.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/__pycache__/diffusers_cli.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e0cf1195f09df07b09859985f88695edba9287d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/__pycache__/diffusers_cli.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/__pycache__/env.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/__pycache__/env.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee35fcce2154e5605470c80038a939ef1fbaa050 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/__pycache__/env.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/__pycache__/fp16_safetensors.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/__pycache__/fp16_safetensors.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f8596bf60ba3011526e2d5c3a314d42e6d3c4d7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/commands/__pycache__/fp16_safetensors.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/experimental/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/experimental/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ece75dfeadc37e518e523e001fdf375c5d5b7b9f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/experimental/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/experimental/rl/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/experimental/rl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7b338d3173e12d478b6b6d6fd0e50650a0ab5a4c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/experimental/rl/__init__.py @@ -0,0 +1 @@ +from .value_guided_sampling import ValueGuidedRLPipeline diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/experimental/rl/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/experimental/rl/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f14afefa9052dd3116fde95e19296e4e394f869a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/experimental/rl/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/experimental/rl/__pycache__/value_guided_sampling.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/experimental/rl/__pycache__/value_guided_sampling.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2c1196ffe965df4178bc857e01f5b749813c351 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/experimental/rl/__pycache__/value_guided_sampling.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/experimental/rl/value_guided_sampling.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/experimental/rl/value_guided_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..c69d308ecc6894edd39557d7f16fc01914e4bd85 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/experimental/rl/value_guided_sampling.py @@ -0,0 +1,153 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +import tqdm + +from ...models.unets.unet_1d import UNet1DModel +from ...pipelines import DiffusionPipeline +from ...utils.dummy_pt_objects import DDPMScheduler +from ...utils.torch_utils import randn_tensor + + +class ValueGuidedRLPipeline(DiffusionPipeline): + r""" + Pipeline for value-guided sampling from a diffusion model trained to predict sequences of states. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + value_function ([`UNet1DModel`]): + A specialized UNet for fine-tuning trajectories base on reward. + unet ([`UNet1DModel`]): + UNet architecture to denoise the encoded trajectories. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this + application is [`DDPMScheduler`]. + env (): + An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models. + """ + + def __init__( + self, + value_function: UNet1DModel, + unet: UNet1DModel, + scheduler: DDPMScheduler, + env, + ): + super().__init__() + + self.register_modules(value_function=value_function, unet=unet, scheduler=scheduler, env=env) + + self.data = env.get_dataset() + self.means = {} + for key in self.data.keys(): + try: + self.means[key] = self.data[key].mean() + except: # noqa: E722 + pass + self.stds = {} + for key in self.data.keys(): + try: + self.stds[key] = self.data[key].std() + except: # noqa: E722 + pass + self.state_dim = env.observation_space.shape[0] + self.action_dim = env.action_space.shape[0] + + def normalize(self, x_in, key): + return (x_in - self.means[key]) / self.stds[key] + + def de_normalize(self, x_in, key): + return x_in * self.stds[key] + self.means[key] + + def to_torch(self, x_in): + if isinstance(x_in, dict): + return {k: self.to_torch(v) for k, v in x_in.items()} + elif torch.is_tensor(x_in): + return x_in.to(self.unet.device) + return torch.tensor(x_in, device=self.unet.device) + + def reset_x0(self, x_in, cond, act_dim): + for key, val in cond.items(): + x_in[:, key, act_dim:] = val.clone() + return x_in + + def run_diffusion(self, x, conditions, n_guide_steps, scale): + batch_size = x.shape[0] + y = None + for i in tqdm.tqdm(self.scheduler.timesteps): + # create batch of timesteps to pass into model + timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long) + for _ in range(n_guide_steps): + with torch.enable_grad(): + x.requires_grad_() + + # permute to match dimension for pre-trained models + y = self.value_function(x.permute(0, 2, 1), timesteps).sample + grad = torch.autograd.grad([y.sum()], [x])[0] + + posterior_variance = self.scheduler._get_variance(i) + model_std = torch.exp(0.5 * posterior_variance) + grad = model_std * grad + + grad[timesteps < 2] = 0 + x = x.detach() + x = x + scale * grad + x = self.reset_x0(x, conditions, self.action_dim) + + prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1) + + # TODO: verify deprecation of this kwarg + x = self.scheduler.step(prev_x, i, x)["prev_sample"] + + # apply conditions to the trajectory (set the initial state) + x = self.reset_x0(x, conditions, self.action_dim) + x = self.to_torch(x) + return x, y + + def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1): + # normalize the observations and create batch dimension + obs = self.normalize(obs, "observations") + obs = obs[None].repeat(batch_size, axis=0) + + conditions = {0: self.to_torch(obs)} + shape = (batch_size, planning_horizon, self.state_dim + self.action_dim) + + # generate initial noise and apply our conditions (to make the trajectories start at current state) + x1 = randn_tensor(shape, device=self.unet.device) + x = self.reset_x0(x1, conditions, self.action_dim) + x = self.to_torch(x) + + # run the diffusion process + x, y = self.run_diffusion(x, conditions, n_guide_steps, scale) + + # sort output trajectories by value + sorted_idx = y.argsort(0, descending=True).squeeze() + sorted_values = x[sorted_idx] + actions = sorted_values[:, :, : self.action_dim] + actions = actions.detach().cpu().numpy() + denorm_actions = self.de_normalize(actions, key="actions") + + # select the action with the highest value + if y is not None: + selected_index = 0 + else: + # if we didn't run value guiding, select a random action + selected_index = np.random.randint(0, batch_size) + + denorm_actions = denorm_actions[selected_index, 0] + return denorm_actions diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f9c3d2acade5be19df87a4ca1a25ed62c8d96f3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/adaptive_projected_guidance.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/adaptive_projected_guidance.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0eab3b319ce16ca64035431e396f863441664522 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/adaptive_projected_guidance.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/adaptive_projected_guidance_mix.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/adaptive_projected_guidance_mix.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f13c858c0789552f26efeca951a4a9224f7b51fd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/adaptive_projected_guidance_mix.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/auto_guidance.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/auto_guidance.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..621e72edc246243f452b7f95e4a3168167d80001 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/auto_guidance.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/classifier_free_guidance.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/classifier_free_guidance.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0c02149967b23d8fa49c28cceeaa307b85c9098 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/classifier_free_guidance.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/classifier_free_zero_star_guidance.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/classifier_free_zero_star_guidance.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ceed51fb1a36c966a92640600486d3fff65d469d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/classifier_free_zero_star_guidance.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/frequency_decoupled_guidance.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/frequency_decoupled_guidance.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..576cc319ee925cd910aeadb0155503b8184ec93f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/frequency_decoupled_guidance.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/guider_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/guider_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae29d583cb86948e3780a2a7c3a5f1b5b86dfec8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/guider_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/magnitude_aware_guidance.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/magnitude_aware_guidance.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a95dc9805d8ed3a81dcb95a859d1d9bb3a45c96c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/magnitude_aware_guidance.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/perturbed_attention_guidance.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/perturbed_attention_guidance.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..349d94e4ed8439c7136fefc289b4f992cda495f8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/perturbed_attention_guidance.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/skip_layer_guidance.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/skip_layer_guidance.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff418552ca7d9e60e170906df4be8345aa4cc253 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/skip_layer_guidance.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/smoothed_energy_guidance.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/smoothed_energy_guidance.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5c810d0322222f59add84a4d00fb0d413d2df5e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/smoothed_energy_guidance.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/tangential_classifier_free_guidance.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/tangential_classifier_free_guidance.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11fccc8b52fccba17738aad66d500c64d8a0af6b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/guiders/__pycache__/tangential_classifier_free_guidance.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ad7a72bfeb3aa67e54f6514de2c1e351e141553 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/_modeling_parallel.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/_modeling_parallel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e386e023ca71df5e369abda5bf60e3bece65a31e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/_modeling_parallel.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/activations.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/activations.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4337c3e5c1f7a44eb90dbd484341bfedefcf3d29 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/activations.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/adapter.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/adapter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f299b5377968e44c1392e738adc7799345ec324 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/adapter.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/attention.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/attention.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8b269334ed83e45cf98c1f937d2f0c9ce7760cc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/attention.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/attention_flax.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/attention_flax.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f67a38370d253cc87627466e4455b16f8625042 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/attention_flax.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/auto_model.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/auto_model.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb537bbe8e453b0386f3e18e4fa27aa88c22d1e1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/auto_model.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/embeddings_flax.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/embeddings_flax.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c53b9fb9c4bf573d67b7b2c88a553d9a77d9560b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/embeddings_flax.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/lora.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/lora.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e045628b7576d3b8a2fda3f7c86d0fef3c1589ef Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/lora.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/model_loading_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/model_loading_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5722da2091369c6e5742102f4014fb6c56ac4f6a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/model_loading_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/modeling_flax_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/modeling_flax_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfa06625dc8f263996a6479de19727bc1d80eb15 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/modeling_flax_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/modeling_outputs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/modeling_outputs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c25ed44177215df2dfe49298dc99adfca79b8f2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/modeling_outputs.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/modeling_pytorch_flax_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/modeling_pytorch_flax_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14e71ca738b1598f47688ae050575f7ba6cd8687 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/modeling_pytorch_flax_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/modeling_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/modeling_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f1009c6bcbd0c75805de38f72e37e2e98aae500 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/modeling_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/normalization.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/normalization.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8e793131aa535553e09a500f30c313c9a87a404 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/normalization.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/resnet.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/resnet.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c47b74362416020a097200e3f9a83c7f2eef7534 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/resnet.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/resnet_flax.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/resnet_flax.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d11c3c7858d3dc28130c3c055903c5f9a07fd738 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/resnet_flax.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/upsampling.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/upsampling.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78bfa4a6bdbaafeb198f7d1737a2d62dfc8a8b4a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/upsampling.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/vae_flax.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/vae_flax.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..edbd16bbcba3e9bdf3a4bb1a904274989fd3e37c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/vae_flax.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/vq_model.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/vq_model.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5dbd58713bcf97177e1b22eab89eb8aec54297a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/__pycache__/vq_model.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..23665ee0532e6a7bbb63bf13a49affe62d3a9247 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/__init__.py @@ -0,0 +1,24 @@ +from .autoencoder_asym_kl import AsymmetricAutoencoderKL +from .autoencoder_dc import AutoencoderDC +from .autoencoder_kl import AutoencoderKL +from .autoencoder_kl_allegro import AutoencoderKLAllegro +from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX +from .autoencoder_kl_cosmos import AutoencoderKLCosmos +from .autoencoder_kl_flux2 import AutoencoderKLFlux2 +from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo +from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage +from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner +from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15 +from .autoencoder_kl_ltx import AutoencoderKLLTXVideo +from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video +from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio +from .autoencoder_kl_magvit import AutoencoderKLMagvit +from .autoencoder_kl_mochi import AutoencoderKLMochi +from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage +from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder +from .autoencoder_kl_wan import AutoencoderKLWan +from .autoencoder_oobleck import AutoencoderOobleck +from .autoencoder_rae import AutoencoderRAE +from .autoencoder_tiny import AutoencoderTiny +from .consistency_decoder_vae import ConsistencyDecoderVAE +from .vq_model import VQModel diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_asym_kl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_asym_kl.py new file mode 100644 index 0000000000000000000000000000000000000000..fbd9b3e459f79e638abdf4fde310a16b94747a0c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_asym_kl.py @@ -0,0 +1,181 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils.accelerate_utils import apply_forward_hook +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder + + +class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin): + r""" + Designing a Better Asymmetric VQGAN for StableDiffusion https://huggingface.co/papers/2306.04632 . A VAE model with + KL loss for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + tuple of downsample block types. + down_block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`): + tuple of down block output channels. + layers_per_down_block (`int`, *optional*, defaults to `1`): + Number layers for down block. + up_block_types (`tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + tuple of upsample block types. + up_block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`): + tuple of up block output channels. + layers_per_up_block (`int`, *optional*, defaults to `1`): + Number layers for up block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + norm_num_groups (`int`, *optional*, defaults to `32`): + Number of groups to use for the first normalization layer in ResNet blocks. + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper. + """ + + _skip_layerwise_casting_patterns = ["decoder"] + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: tuple[str, ...] = ("DownEncoderBlock2D",), + down_block_out_channels: tuple[int, ...] = (64,), + layers_per_down_block: int = 1, + up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",), + up_block_out_channels: tuple[int, ...] = (64,), + layers_per_up_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + sample_size: int = 32, + scaling_factor: float = 0.18215, + ) -> None: + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=down_block_out_channels, + layers_per_block=layers_per_down_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = MaskConditionDecoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=up_block_out_channels, + layers_per_block=layers_per_up_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) + + self.register_to_config(block_out_channels=up_block_out_channels) + self.register_to_config(force_upcast=False) + + @apply_forward_hook + def encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput | tuple[torch.Tensor]: + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode( + self, + z: torch.Tensor, + image: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + return_dict: bool = True, + ) -> DecoderOutput | tuple[torch.Tensor]: + z = self.post_quant_conv(z) + dec = self.decoder(z, image, mask) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, + z: torch.Tensor, + generator: torch.Generator | None = None, + image: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + return_dict: bool = True, + ) -> DecoderOutput | tuple[torch.Tensor]: + decoded = self._decode(z, image, mask).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.Tensor, + mask: torch.Tensor | None = None, + sample_posterior: bool = False, + return_dict: bool = True, + generator: torch.Generator | None = None, + ) -> DecoderOutput | tuple[torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + mask (`torch.Tensor`, *optional*, defaults to `None`): Optional inpainting mask. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, generator, sample, mask).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_dc.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_dc.py new file mode 100644 index 0000000000000000000000000000000000000000..02a83d79aba56a69df9946a5a91ae878eccd9d62 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_dc.py @@ -0,0 +1,713 @@ +# Copyright 2025 MIT, Tsinghua University, NVIDIA CORPORATION and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..attention_processor import SanaMultiscaleLinearAttention +from ..modeling_utils import ModelMixin +from ..normalization import RMSNorm, get_normalization +from ..transformers.sana_transformer import GLUMBConv +from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput + + +class ResBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + norm_type: str = "batch_norm", + act_fn: str = "relu6", + ) -> None: + super().__init__() + + self.norm_type = norm_type + + self.nonlinearity = get_activation(act_fn) if act_fn is not None else nn.Identity() + self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1) + self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False) + self.norm = get_normalization(norm_type, out_channels) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + hidden_states = self.conv1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.norm_type == "rms_norm": + # move channel to the last dimension so we apply RMSnorm across channel dimension + hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1) + else: + hidden_states = self.norm(hidden_states) + + return hidden_states + residual + + +class EfficientViTBlock(nn.Module): + def __init__( + self, + in_channels: int, + mult: float = 1.0, + attention_head_dim: int = 32, + qkv_multiscales: tuple[int, ...] = (5,), + norm_type: str = "batch_norm", + ) -> None: + super().__init__() + + self.attn = SanaMultiscaleLinearAttention( + in_channels=in_channels, + out_channels=in_channels, + mult=mult, + attention_head_dim=attention_head_dim, + norm_type=norm_type, + kernel_sizes=qkv_multiscales, + residual_connection=True, + ) + + self.conv_out = GLUMBConv( + in_channels=in_channels, + out_channels=in_channels, + norm_type="rms_norm", + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.attn(x) + x = self.conv_out(x) + return x + + +def get_block( + block_type: str, + in_channels: int, + out_channels: int, + attention_head_dim: int, + norm_type: str, + act_fn: str, + qkv_multiscales: tuple[int, ...] = (), +): + if block_type == "ResBlock": + block = ResBlock(in_channels, out_channels, norm_type, act_fn) + + elif block_type == "EfficientViTBlock": + block = EfficientViTBlock( + in_channels, attention_head_dim=attention_head_dim, norm_type=norm_type, qkv_multiscales=qkv_multiscales + ) + + else: + raise ValueError(f"Block with {block_type=} is not supported.") + + return block + + +class DCDownBlock2d(nn.Module): + def __init__(self, in_channels: int, out_channels: int, downsample: bool = False, shortcut: bool = True) -> None: + super().__init__() + + self.downsample = downsample + self.factor = 2 + self.stride = 1 if downsample else 2 + self.group_size = in_channels * self.factor**2 // out_channels + self.shortcut = shortcut + + out_ratio = self.factor**2 + if downsample: + assert out_channels % out_ratio == 0 + out_channels = out_channels // out_ratio + + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=self.stride, + padding=1, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + x = self.conv(hidden_states) + if self.downsample: + x = F.pixel_unshuffle(x, self.factor) + + if self.shortcut: + y = F.pixel_unshuffle(hidden_states, self.factor) + y = y.unflatten(1, (-1, self.group_size)) + y = y.mean(dim=2) + hidden_states = x + y + else: + hidden_states = x + + return hidden_states + + +class DCUpBlock2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + interpolate: bool = False, + shortcut: bool = True, + interpolation_mode: str = "nearest", + ) -> None: + super().__init__() + + self.interpolate = interpolate + self.interpolation_mode = interpolation_mode + self.shortcut = shortcut + self.factor = 2 + self.repeats = out_channels * self.factor**2 // in_channels + + out_ratio = self.factor**2 + + if not interpolate: + out_channels = out_channels * out_ratio + + self.conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.interpolate: + x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode) + x = self.conv(x) + else: + x = self.conv(hidden_states) + x = F.pixel_shuffle(x, self.factor) + + if self.shortcut: + y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats) + y = F.pixel_shuffle(y, self.factor) + hidden_states = x + y + else: + hidden_states = x + + return hidden_states + + +class Encoder(nn.Module): + def __init__( + self, + in_channels: int, + latent_channels: int, + attention_head_dim: int = 32, + block_type: str | tuple[str] = "ResBlock", + block_out_channels: tuple[int, ...] = (128, 256, 512, 512, 1024, 1024), + layers_per_block: tuple[int, ...] = (2, 2, 2, 2, 2, 2), + qkv_multiscales: tuple[tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)), + downsample_block_type: str = "pixel_unshuffle", + out_shortcut: bool = True, + ): + super().__init__() + + num_blocks = len(block_out_channels) + + if isinstance(block_type, str): + block_type = (block_type,) * num_blocks + + if layers_per_block[0] > 0: + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1], + kernel_size=3, + stride=1, + padding=1, + ) + else: + self.conv_in = DCDownBlock2d( + in_channels=in_channels, + out_channels=block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1], + downsample=downsample_block_type == "pixel_unshuffle", + shortcut=False, + ) + + down_blocks = [] + for i, (out_channel, num_layers) in enumerate(zip(block_out_channels, layers_per_block)): + down_block_list = [] + + for _ in range(num_layers): + block = get_block( + block_type[i], + out_channel, + out_channel, + attention_head_dim=attention_head_dim, + norm_type="rms_norm", + act_fn="silu", + qkv_multiscales=qkv_multiscales[i], + ) + down_block_list.append(block) + + if i < num_blocks - 1 and num_layers > 0: + downsample_block = DCDownBlock2d( + in_channels=out_channel, + out_channels=block_out_channels[i + 1], + downsample=downsample_block_type == "pixel_unshuffle", + shortcut=True, + ) + down_block_list.append(downsample_block) + + down_blocks.append(nn.Sequential(*down_block_list)) + + self.down_blocks = nn.ModuleList(down_blocks) + + self.conv_out = nn.Conv2d(block_out_channels[-1], latent_channels, 3, 1, 1) + + self.out_shortcut = out_shortcut + if out_shortcut: + self.out_shortcut_average_group_size = block_out_channels[-1] // latent_channels + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states) + + if self.out_shortcut: + x = hidden_states.unflatten(1, (-1, self.out_shortcut_average_group_size)) + x = x.mean(dim=2) + hidden_states = self.conv_out(hidden_states) + x + else: + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +class Decoder(nn.Module): + def __init__( + self, + in_channels: int, + latent_channels: int, + attention_head_dim: int = 32, + block_type: str | tuple[str] = "ResBlock", + block_out_channels: tuple[int, ...] = (128, 256, 512, 512, 1024, 1024), + layers_per_block: tuple[int, ...] = (2, 2, 2, 2, 2, 2), + qkv_multiscales: tuple[tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)), + norm_type: str | tuple[str] = "rms_norm", + act_fn: str | tuple[str] = "silu", + upsample_block_type: str = "pixel_shuffle", + in_shortcut: bool = True, + conv_act_fn: str = "relu", + ): + super().__init__() + + num_blocks = len(block_out_channels) + + if isinstance(block_type, str): + block_type = (block_type,) * num_blocks + if isinstance(norm_type, str): + norm_type = (norm_type,) * num_blocks + if isinstance(act_fn, str): + act_fn = (act_fn,) * num_blocks + + self.conv_in = nn.Conv2d(latent_channels, block_out_channels[-1], 3, 1, 1) + + self.in_shortcut = in_shortcut + if in_shortcut: + self.in_shortcut_repeats = block_out_channels[-1] // latent_channels + + up_blocks = [] + for i, (out_channel, num_layers) in reversed(list(enumerate(zip(block_out_channels, layers_per_block)))): + up_block_list = [] + + if i < num_blocks - 1 and num_layers > 0: + upsample_block = DCUpBlock2d( + block_out_channels[i + 1], + out_channel, + interpolate=upsample_block_type == "interpolate", + shortcut=True, + ) + up_block_list.append(upsample_block) + + for _ in range(num_layers): + block = get_block( + block_type[i], + out_channel, + out_channel, + attention_head_dim=attention_head_dim, + norm_type=norm_type[i], + act_fn=act_fn[i], + qkv_multiscales=qkv_multiscales[i], + ) + up_block_list.append(block) + + up_blocks.insert(0, nn.Sequential(*up_block_list)) + + self.up_blocks = nn.ModuleList(up_blocks) + + channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1] + + self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True) + self.conv_act = get_activation(conv_act_fn) + self.conv_out = None + + if layers_per_block[0] > 0: + self.conv_out = nn.Conv2d(channels, in_channels, 3, 1, 1) + else: + self.conv_out = DCUpBlock2d( + channels, in_channels, interpolate=upsample_block_type == "interpolate", shortcut=False + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.in_shortcut: + x = hidden_states.repeat_interleave( + self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats + ) + hidden_states = self.conv_in(hidden_states) + x + else: + hidden_states = self.conv_in(hidden_states) + + for up_block in reversed(self.up_blocks): + hidden_states = up_block(hidden_states) + + hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + return hidden_states + + +class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): + r""" + An Autoencoder model introduced in [DCAE](https://huggingface.co/papers/2410.10733) and used in + [SANA](https://huggingface.co/papers/2410.10629). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Args: + in_channels (`int`, defaults to `3`): + The number of input channels in samples. + latent_channels (`int`, defaults to `32`): + The number of channels in the latent space representation. + encoder_block_types (`str | tuple[str]`, defaults to `"ResBlock"`): + The type(s) of block to use in the encoder. + decoder_block_types (`str | tuple[str]`, defaults to `"ResBlock"`): + The type(s) of block to use in the decoder. + encoder_block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`): + The number of output channels for each block in the encoder. + decoder_block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`): + The number of output channels for each block in the decoder. + encoder_layers_per_block (`tuple[int]`, defaults to `(2, 2, 2, 3, 3, 3)`): + The number of layers per block in the encoder. + decoder_layers_per_block (`tuple[int]`, defaults to `(3, 3, 3, 3, 3, 3)`): + The number of layers per block in the decoder. + encoder_qkv_multiscales (`tuple[tuple[int, ...], ...]`, defaults to `((), (), (), (5,), (5,), (5,))`): + Multi-scale configurations for the encoder's QKV (query-key-value) transformations. + decoder_qkv_multiscales (`tuple[tuple[int, ...], ...]`, defaults to `((), (), (), (5,), (5,), (5,))`): + Multi-scale configurations for the decoder's QKV (query-key-value) transformations. + upsample_block_type (`str`, defaults to `"pixel_shuffle"`): + The type of block to use for upsampling in the decoder. + downsample_block_type (`str`, defaults to `"pixel_unshuffle"`): + The type of block to use for downsampling in the encoder. + decoder_norm_types (`str | tuple[str]`, defaults to `"rms_norm"`): + The normalization type(s) to use in the decoder. + decoder_act_fns (`str | tuple[str]`, defaults to `"silu"`): + The activation function(s) to use in the decoder. + encoder_out_shortcut (`bool`, defaults to `True`): + Whether to use shortcut at the end of the encoder. + decoder_in_shortcut (`bool`, defaults to `True`): + Whether to use shortcut at the beginning of the decoder. + decoder_conv_act_fn (`str`, defaults to `"relu"`): + The activation function to use at the end of the decoder. + scaling_factor (`float`, defaults to `1.0`): + The multiplicative inverse of the root mean square of the latent features. This is used to scale the latent + space to have unit variance when training the diffusion model. The latents are scaled with the formula `z = + z * scaling_factor` before being passed to the diffusion model. When decoding, the latents are scaled back + to the original scale with the formula: `z = 1 / scaling_factor * z`. + """ + + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + in_channels: int = 3, + latent_channels: int = 32, + attention_head_dim: int = 32, + encoder_block_types: str | tuple[str] = "ResBlock", + decoder_block_types: str | tuple[str] = "ResBlock", + encoder_block_out_channels: tuple[int, ...] = (128, 256, 512, 512, 1024, 1024), + decoder_block_out_channels: tuple[int, ...] = (128, 256, 512, 512, 1024, 1024), + encoder_layers_per_block: tuple[int, ...] = (2, 2, 2, 3, 3, 3), + decoder_layers_per_block: tuple[int, ...] = (3, 3, 3, 3, 3, 3), + encoder_qkv_multiscales: tuple[tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)), + decoder_qkv_multiscales: tuple[tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)), + upsample_block_type: str = "pixel_shuffle", + downsample_block_type: str = "pixel_unshuffle", + decoder_norm_types: str | tuple[str] = "rms_norm", + decoder_act_fns: str | tuple[str] = "silu", + encoder_out_shortcut: bool = True, + decoder_in_shortcut: bool = True, + decoder_conv_act_fn: str = "relu", + scaling_factor: float = 1.0, + ) -> None: + super().__init__() + + self.encoder = Encoder( + in_channels=in_channels, + latent_channels=latent_channels, + attention_head_dim=attention_head_dim, + block_type=encoder_block_types, + block_out_channels=encoder_block_out_channels, + layers_per_block=encoder_layers_per_block, + qkv_multiscales=encoder_qkv_multiscales, + downsample_block_type=downsample_block_type, + out_shortcut=encoder_out_shortcut, + ) + self.decoder = Decoder( + in_channels=in_channels, + latent_channels=latent_channels, + attention_head_dim=attention_head_dim, + block_type=decoder_block_types, + block_out_channels=decoder_block_out_channels, + layers_per_block=decoder_layers_per_block, + qkv_multiscales=decoder_qkv_multiscales, + norm_type=decoder_norm_types, + act_fn=decoder_act_fns, + upsample_block_type=upsample_block_type, + in_shortcut=decoder_in_shortcut, + conv_act_fn=decoder_conv_act_fn, + ) + + self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1) + self.temporal_compression_ratio = 1 + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 512 + self.tile_sample_min_width = 512 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 448 + self.tile_sample_stride_width = 448 + + self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + + def enable_tiling( + self, + tile_sample_min_height: int | None = None, + tile_sample_min_width: int | None = None, + tile_sample_stride_height: float | None = None, + tile_sample_stride_width: float | None = None, + ) -> None: + r""" + Enable tiled AE decoding. When this option is enabled, the AE will split the input tensor into tiles to compute + decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x, return_dict=False)[0] + + encoded = self.encoder(x) + + return encoded + + @apply_forward_hook + def encode(self, x: torch.Tensor, return_dict: bool = True) -> EncoderOutput | tuple[torch.Tensor]: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, defaults to `True`): + Whether to return a [`~models.vae.EncoderOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.vae.EncoderOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + encoded = torch.cat(encoded_slices) + else: + encoded = self._encode(x) + + if not return_dict: + return (encoded,) + return EncoderOutput(latent=encoded) + + def _decode(self, z: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = z.shape + + if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height): + return self.tiled_decode(z, return_dict=False)[0] + + decoded = self.decoder(z) + + return decoded + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor]: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.size(0) > 1: + decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z) + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) + return b + + def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor: + batch_size, num_channels, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, x.shape[2], self.tile_sample_stride_height): + row = [] + for j in range(0, x.shape[3], self.tile_sample_stride_width): + tile = x[:, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + if ( + tile.shape[2] % self.spatial_compression_ratio != 0 + or tile.shape[3] % self.spatial_compression_ratio != 0 + ): + pad_h = (self.spatial_compression_ratio - tile.shape[2]) % self.spatial_compression_ratio + pad_w = (self.spatial_compression_ratio - tile.shape[3]) % self.spatial_compression_ratio + tile = F.pad(tile, (0, pad_w, 0, pad_h)) + tile = self.encoder(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=3)) + + encoded = torch.cat(result_rows, dim=2)[:, :, :latent_height, :latent_width] + + if not return_dict: + return (encoded,) + return EncoderOutput(latent=encoded) + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + batch_size, num_channels, height, width = z.shape + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + tile = z[:, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width] + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=3)) + + decoded = torch.cat(result_rows, dim=2) + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def forward(self, sample: torch.Tensor, return_dict: bool = True) -> torch.Tensor: + encoded = self.encode(sample, return_dict=False)[0] + decoded = self.decode(encoded, return_dict=False)[0] + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl.py new file mode 100644 index 0000000000000000000000000000000000000000..d2e7318f56795d0a21519767f104908cd3cfe42e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl.py @@ -0,0 +1,471 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...loaders.single_file_model import FromOriginalModelMixin +from ...utils import deprecate +from ...utils.accelerate_utils import apply_forward_hook +from ..attention import AttentionMixin +from ..attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttnAddedKVProcessor, + AttnProcessor, + FusedAttnProcessor2_0, +) +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder + + +class AutoencoderKL( + ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin +): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + tuple of downsample block types. + up_block_types (`tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + tuple of upsample block types. + block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`): + tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast` + can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + mid_block_add_attention (`bool`, *optional*, default to `True`): + If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the + mid_block will only have resnet blocks + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"] + _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"] + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: tuple[str] = ("DownEncoderBlock2D",), + up_block_types: tuple[str] = ("UpDecoderBlock2D",), + block_out_channels: tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + sample_size: int = 32, + scaling_factor: float = 0.18215, + shift_factor: float | None = None, + latents_mean: tuple[float] | None = None, + latents_std: tuple[float] | None = None, + force_upcast: bool = True, + use_quant_conv: bool = True, + use_post_quant_conv: bool = True, + mid_block_add_attention: bool = True, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + mid_block_add_attention=mid_block_add_attention, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + mid_block_add_attention=mid_block_add_attention, + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None + + self.use_slicing = False + self.use_tiling = False + + # only relevant if vae tiling is enabled + self.tile_sample_min_size = self.config.sample_size + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor = 0.25 + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size): + return self._tiled_encode(x) + + enc = self.encoder(x) + if self.quant_conv is not None: + enc = self.quant_conv(enc) + + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): + return self.tiled_decode(z, return_dict=return_dict) + + if self.post_quant_conv is not None: + z = self.post_quant_conv(z) + + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, z: torch.FloatTensor, return_dict: bool = True, generator=None + ) -> DecoderOutput | torch.FloatTensor: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) + return b + + def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + enc = torch.cat(result_rows, dim=2) + return enc + + def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain + `tuple` is returned. + """ + deprecation_message = ( + "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the " + "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able " + "to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value." + ) + deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False) + + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + moments = torch.cat(result_rows, dim=2) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[2], overlap_size): + row = [] + for j in range(0, z.shape[3], overlap_size): + tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] + if self.config.use_post_quant_conv: + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + dec = torch.cat(result_rows, dim=2) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: torch.Generator | None = None, + ) -> DecoderOutput | torch.Tensor: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + > [!WARNING] > This API is 🧪 experimental. + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + > [!WARNING] > This API is 🧪 experimental. + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_allegro.py new file mode 100644 index 0000000000000000000000000000000000000000..463f8f41bc1047bc6501cbc449ffadffa0d76af5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_allegro.py @@ -0,0 +1,1102 @@ +# Copyright 2025 The RhymesAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils.accelerate_utils import apply_forward_hook +from ..attention_processor import Attention, SpatialNorm +from ..autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution +from ..downsampling import Downsample2D +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from ..resnet import ResnetBlock2D +from ..upsampling import Upsample2D +from .vae import AutoencoderMixin + + +class AllegroTemporalConvLayer(nn.Module): + r""" + Temporal convolutional layer that can be used for video (sequence of images) input. Code adapted from: + https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016 + """ + + def __init__( + self, + in_dim: int, + out_dim: int | None = None, + dropout: float = 0.0, + norm_num_groups: int = 32, + up_sample: bool = False, + down_sample: bool = False, + stride: int = 1, + ) -> None: + super().__init__() + + out_dim = out_dim or in_dim + pad_h = pad_w = int((stride - 1) * 0.5) + pad_t = 0 + + self.down_sample = down_sample + self.up_sample = up_sample + + if down_sample: + self.conv1 = nn.Sequential( + nn.GroupNorm(norm_num_groups, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (2, stride, stride), stride=(2, 1, 1), padding=(0, pad_h, pad_w)), + ) + elif up_sample: + self.conv1 = nn.Sequential( + nn.GroupNorm(norm_num_groups, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim * 2, (1, stride, stride), padding=(0, pad_h, pad_w)), + ) + else: + self.conv1 = nn.Sequential( + nn.GroupNorm(norm_num_groups, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)), + ) + self.conv2 = nn.Sequential( + nn.GroupNorm(norm_num_groups, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)), + ) + self.conv3 = nn.Sequential( + nn.GroupNorm(norm_num_groups, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)), + ) + self.conv4 = nn.Sequential( + nn.GroupNorm(norm_num_groups, out_dim), + nn.SiLU(), + nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)), + ) + + @staticmethod + def _pad_temporal_dim(hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2) + hidden_states = torch.cat((hidden_states, hidden_states[:, :, -1:]), dim=2) + return hidden_states + + def forward(self, hidden_states: torch.Tensor, batch_size: int) -> torch.Tensor: + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + + if self.down_sample: + identity = hidden_states[:, :, ::2] + elif self.up_sample: + identity = hidden_states.repeat_interleave(2, dim=2, output_size=hidden_states.shape[2] * 2) + else: + identity = hidden_states + + if self.down_sample or self.up_sample: + hidden_states = self.conv1(hidden_states) + else: + hidden_states = self._pad_temporal_dim(hidden_states) + hidden_states = self.conv1(hidden_states) + + if self.up_sample: + hidden_states = hidden_states.unflatten(1, (2, -1)).permute(0, 2, 3, 1, 4, 5).flatten(2, 3) + + hidden_states = self._pad_temporal_dim(hidden_states) + hidden_states = self.conv2(hidden_states) + + hidden_states = self._pad_temporal_dim(hidden_states) + hidden_states = self.conv3(hidden_states) + + hidden_states = self._pad_temporal_dim(hidden_states) + hidden_states = self.conv4(hidden_states) + + hidden_states = identity + hidden_states + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + + return hidden_states + + +class AllegroDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + spatial_downsample: bool = True, + temporal_downsample: bool = False, + downsample_padding: int = 1, + ): + super().__init__() + + resnets = [] + temp_convs = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + AllegroTemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if temporal_downsample: + self.temp_convs_down = AllegroTemporalConvLayer( + out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, down_sample=True, stride=3 + ) + self.add_temp_downsample = temporal_downsample + + if spatial_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size = hidden_states.shape[0] + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = temp_conv(hidden_states, batch_size=batch_size) + + if self.add_temp_downsample: + hidden_states = self.temp_convs_down(hidden_states, batch_size=batch_size) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + return hidden_states + + +class AllegroUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + temb_channels: int | None = None, + ): + super().__init__() + + resnets = [] + temp_convs = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + AllegroTemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + self.add_temp_upsample = temporal_upsample + if temporal_upsample: + self.temp_conv_up = AllegroTemporalConvLayer( + out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, up_sample=True, stride=3 + ) + + if spatial_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size = hidden_states.shape[0] + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = temp_conv(hidden_states, batch_size=batch_size) + + if self.add_temp_upsample: + hidden_states = self.temp_conv_up(hidden_states, batch_size=batch_size) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + return hidden_states + + +class AllegroMidBlock3DConv(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + ): + super().__init__() + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + temp_convs = [ + AllegroTemporalConvLayer( + in_channels, + in_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ] + attentions = [] + + if attention_head_dim is None: + attention_head_dim = in_channels + + for _ in range(num_layers): + if add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + temp_convs.append( + AllegroTemporalConvLayer( + in_channels, + in_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size = hidden_states.shape[0] + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.resnets[0](hidden_states, temb=None) + + hidden_states = self.temp_convs[0](hidden_states, batch_size=batch_size) + + for attn, resnet, temp_conv in zip(self.attentions, self.resnets[1:], self.temp_convs[1:]): + hidden_states = attn(hidden_states) + hidden_states = resnet(hidden_states, temb=None) + hidden_states = temp_conv(hidden_states, batch_size=batch_size) + + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + return hidden_states + + +class AllegroEncoder3D(nn.Module): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: tuple[str, ...] = ( + "AllegroDownBlock3D", + "AllegroDownBlock3D", + "AllegroDownBlock3D", + "AllegroDownBlock3D", + ), + block_out_channels: tuple[int, ...] = (128, 256, 512, 512), + temporal_downsample_blocks: tuple[bool, ...] = [True, True, False, False], + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + ): + super().__init__() + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + ) + + self.temp_conv_in = nn.Conv3d( + in_channels=block_out_channels[0], + out_channels=block_out_channels[0], + kernel_size=(3, 1, 1), + padding=(1, 0, 0), + ) + + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + if down_block_type == "AllegroDownBlock3D": + down_block = AllegroDownBlock3D( + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + spatial_downsample=not is_final_block, + temporal_downsample=temporal_downsample_blocks[i], + resnet_eps=1e-6, + downsample_padding=0, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + ) + else: + raise ValueError("Invalid `down_block_type` encountered. Must be `AllegroDownBlock3D`") + + self.down_blocks.append(down_block) + + # mid + self.mid_block = AllegroMidBlock3DConv( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + ) + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + + self.temp_conv_out = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3, 1, 1), padding=(1, 0, 0)) + self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + batch_size = sample.shape[0] + + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) + sample = self.conv_in(sample) + + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + residual = sample + sample = self.temp_conv_in(sample) + sample = sample + residual + + if torch.is_grad_enabled() and self.gradient_checkpointing: + # Down blocks + for down_block in self.down_blocks: + sample = self._gradient_checkpointing_func(down_block, sample) + + # Mid block + sample = self._gradient_checkpointing_func(self.mid_block, sample) + else: + # Down blocks + for down_block in self.down_blocks: + sample = down_block(sample) + + # Mid block + sample = self.mid_block(sample) + + # Post process + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + residual = sample + sample = self.temp_conv_out(sample) + sample = sample + residual + + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) + sample = self.conv_out(sample) + + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + return sample + + +class AllegroDecoder3D(nn.Module): + def __init__( + self, + in_channels: int = 4, + out_channels: int = 3, + up_block_types: tuple[str, ...] = ( + "AllegroUpBlock3D", + "AllegroUpBlock3D", + "AllegroUpBlock3D", + "AllegroUpBlock3D", + ), + temporal_upsample_blocks: tuple[bool, ...] = [False, True, True, False], + block_out_channels: tuple[int, ...] = (128, 256, 512, 512), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + ): + super().__init__() + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.temp_conv_in = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3, 1, 1), padding=(1, 0, 0)) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = AllegroMidBlock3DConv( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + if up_block_type == "AllegroUpBlock3D": + up_block = AllegroUpBlock3D( + num_layers=layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + spatial_upsample=not is_final_block, + temporal_upsample=temporal_upsample_blocks[i], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + resnet_time_scale_shift=norm_type, + ) + else: + raise ValueError("Invalid `UP_block_type` encountered. Must be `AllegroUpBlock3D`") + + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + + self.conv_act = nn.SiLU() + + self.temp_conv_out = nn.Conv3d(block_out_channels[0], block_out_channels[0], (3, 1, 1), padding=(1, 0, 0)) + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + batch_size = sample.shape[0] + + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) + sample = self.conv_in(sample) + + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + residual = sample + sample = self.temp_conv_in(sample) + sample = sample + residual + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + + if torch.is_grad_enabled() and self.gradient_checkpointing: + # Mid block + sample = self._gradient_checkpointing_func(self.mid_block, sample) + + # Up blocks + for up_block in self.up_blocks: + sample = self._gradient_checkpointing_func(up_block, sample) + + else: + # Mid block + sample = self.mid_block(sample) + sample = sample.to(upscale_dtype) + + # Up blocks + for up_block in self.up_blocks: + sample = up_block(sample) + + # Post process + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + residual = sample + sample = self.temp_conv_out(sample) + sample = sample + residual + + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) + sample = self.conv_out(sample) + + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + return sample + + +class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in + [Allegro](https://github.com/rhymes-ai/Allegro). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, defaults to `3`): + Number of channels in the input image. + out_channels (int, defaults to `3`): + Number of channels in the output. + down_block_types (`tuple[str, ...]`, defaults to `("AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D")`): + tuple of strings denoting which types of down blocks to use. + up_block_types (`tuple[str, ...]`, defaults to `("AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D")`): + tuple of strings denoting which types of up blocks to use. + block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + tuple of integers denoting number of output channels in each block. + temporal_downsample_blocks (`tuple[bool, ...]`, defaults to `(True, True, False, False)`): + tuple of booleans denoting which blocks to enable temporal downsampling in. + latent_channels (`int`, defaults to `4`): + Number of channels in latents. + layers_per_block (`int`, defaults to `2`): + Number of resnet or attention or temporal convolution layers per down/up block. + act_fn (`str`, defaults to `"silu"`): + The activation function to use. + norm_num_groups (`int`, defaults to `32`): + Number of groups to use in normalization layers. + temporal_compression_ratio (`int`, defaults to `4`): + Ratio by which temporal dimension of samples are compressed. + sample_size (`int`, defaults to `320`): + Default latent size. + scaling_factor (`float`, defaults to `0.13235`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper. + force_upcast (`bool`, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast` + can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: tuple[str, ...] = ( + "AllegroDownBlock3D", + "AllegroDownBlock3D", + "AllegroDownBlock3D", + "AllegroDownBlock3D", + ), + up_block_types: tuple[str, ...] = ( + "AllegroUpBlock3D", + "AllegroUpBlock3D", + "AllegroUpBlock3D", + "AllegroUpBlock3D", + ), + block_out_channels: tuple[int, ...] = (128, 256, 512, 512), + temporal_downsample_blocks: tuple[bool, ...] = (True, True, False, False), + temporal_upsample_blocks: tuple[bool, ...] = (False, True, True, False), + latent_channels: int = 4, + layers_per_block: int = 2, + act_fn: str = "silu", + norm_num_groups: int = 32, + temporal_compression_ratio: float = 4, + sample_size: int = 320, + scaling_factor: float = 0.13, + force_upcast: bool = True, + ) -> None: + super().__init__() + + self.encoder = AllegroEncoder3D( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + temporal_downsample_blocks=temporal_downsample_blocks, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + ) + self.decoder = AllegroDecoder3D( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + temporal_upsample_blocks=temporal_upsample_blocks, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + ) + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) + + # TODO(aryan): For the 1.0.0 refactor, `temporal_compression_ratio` can be inferred directly and we don't need + # to use a specific parameter here or in other VAEs. + + self.use_slicing = False + self.use_tiling = False + + self.spatial_compression_ratio = 2 ** (len(block_out_channels) - 1) + self.tile_overlap_t = 8 + self.tile_overlap_h = 120 + self.tile_overlap_w = 80 + sample_frames = 24 + + self.kernel = (sample_frames, sample_size, sample_size) + self.stride = ( + sample_frames - self.tile_overlap_t, + sample_size - self.tile_overlap_h, + sample_size - self.tile_overlap_w, + ) + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + # TODO(aryan) + # if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + if self.use_tiling: + return self.tiled_encode(x) + + raise NotImplementedError("Encoding without tiling has not been implemented yet.") + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]: + r""" + Encode a batch of videos into latents. + + Args: + x (`torch.Tensor`): + Input batch of videos. + return_dict (`bool`, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor) -> torch.Tensor: + # TODO(aryan): refactor tiling implementation + # if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height): + if self.use_tiling: + return self.tiled_decode(z) + + raise NotImplementedError("Decoding without tiling has not been implemented yet.") + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + """ + Decode a batch of videos. + + Args: + z (`torch.Tensor`): + Input batch of latent vectors. + return_dict (`bool`, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z) + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + local_batch_size = 1 + rs = self.spatial_compression_ratio + rt = self.config.temporal_compression_ratio + + batch_size, num_channels, num_frames, height, width = x.shape + + output_num_frames = math.floor((num_frames - self.kernel[0]) / self.stride[0]) + 1 + output_height = math.floor((height - self.kernel[1]) / self.stride[1]) + 1 + output_width = math.floor((width - self.kernel[2]) / self.stride[2]) + 1 + + count = 0 + output_latent = x.new_zeros( + ( + output_num_frames * output_height * output_width, + 2 * self.config.latent_channels, + self.kernel[0] // rt, + self.kernel[1] // rs, + self.kernel[2] // rs, + ) + ) + vae_batch_input = x.new_zeros((local_batch_size, num_channels, self.kernel[0], self.kernel[1], self.kernel[2])) + + for i in range(output_num_frames): + for j in range(output_height): + for k in range(output_width): + n_start, n_end = i * self.stride[0], i * self.stride[0] + self.kernel[0] + h_start, h_end = j * self.stride[1], j * self.stride[1] + self.kernel[1] + w_start, w_end = k * self.stride[2], k * self.stride[2] + self.kernel[2] + + video_cube = x[:, :, n_start:n_end, h_start:h_end, w_start:w_end] + vae_batch_input[count % local_batch_size] = video_cube + + if ( + count % local_batch_size == local_batch_size - 1 + or count == output_num_frames * output_height * output_width - 1 + ): + latent = self.encoder(vae_batch_input) + + if ( + count == output_num_frames * output_height * output_width - 1 + and count % local_batch_size != local_batch_size - 1 + ): + output_latent[count - count % local_batch_size :] = latent[: count % local_batch_size + 1] + else: + output_latent[count - local_batch_size + 1 : count + 1] = latent + + vae_batch_input = x.new_zeros( + (local_batch_size, num_channels, self.kernel[0], self.kernel[1], self.kernel[2]) + ) + + count += 1 + + latent = x.new_zeros( + (batch_size, 2 * self.config.latent_channels, num_frames // rt, height // rs, width // rs) + ) + output_kernel = self.kernel[0] // rt, self.kernel[1] // rs, self.kernel[2] // rs + output_stride = self.stride[0] // rt, self.stride[1] // rs, self.stride[2] // rs + output_overlap = ( + output_kernel[0] - output_stride[0], + output_kernel[1] - output_stride[1], + output_kernel[2] - output_stride[2], + ) + + for i in range(output_num_frames): + n_start, n_end = i * output_stride[0], i * output_stride[0] + output_kernel[0] + for j in range(output_height): + h_start, h_end = j * output_stride[1], j * output_stride[1] + output_kernel[1] + for k in range(output_width): + w_start, w_end = k * output_stride[2], k * output_stride[2] + output_kernel[2] + latent_mean = _prepare_for_blend( + (i, output_num_frames, output_overlap[0]), + (j, output_height, output_overlap[1]), + (k, output_width, output_overlap[2]), + output_latent[i * output_height * output_width + j * output_width + k].unsqueeze(0), + ) + latent[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += latent_mean + + latent = latent.permute(0, 2, 1, 3, 4).flatten(0, 1) + latent = self.quant_conv(latent) + latent = latent.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + return latent + + def tiled_decode(self, z: torch.Tensor) -> torch.Tensor: + local_batch_size = 1 + rs = self.spatial_compression_ratio + rt = self.config.temporal_compression_ratio + + latent_kernel = self.kernel[0] // rt, self.kernel[1] // rs, self.kernel[2] // rs + latent_stride = self.stride[0] // rt, self.stride[1] // rs, self.stride[2] // rs + + batch_size, num_channels, num_frames, height, width = z.shape + + ## post quant conv (a mapping) + z = z.permute(0, 2, 1, 3, 4).flatten(0, 1) + z = self.post_quant_conv(z) + z = z.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + + output_num_frames = math.floor((num_frames - latent_kernel[0]) / latent_stride[0]) + 1 + output_height = math.floor((height - latent_kernel[1]) / latent_stride[1]) + 1 + output_width = math.floor((width - latent_kernel[2]) / latent_stride[2]) + 1 + + count = 0 + decoded_videos = z.new_zeros( + ( + output_num_frames * output_height * output_width, + self.config.out_channels, + self.kernel[0], + self.kernel[1], + self.kernel[2], + ) + ) + vae_batch_input = z.new_zeros( + (local_batch_size, num_channels, latent_kernel[0], latent_kernel[1], latent_kernel[2]) + ) + + for i in range(output_num_frames): + for j in range(output_height): + for k in range(output_width): + n_start, n_end = i * latent_stride[0], i * latent_stride[0] + latent_kernel[0] + h_start, h_end = j * latent_stride[1], j * latent_stride[1] + latent_kernel[1] + w_start, w_end = k * latent_stride[2], k * latent_stride[2] + latent_kernel[2] + + current_latent = z[:, :, n_start:n_end, h_start:h_end, w_start:w_end] + vae_batch_input[count % local_batch_size] = current_latent + + if ( + count % local_batch_size == local_batch_size - 1 + or count == output_num_frames * output_height * output_width - 1 + ): + current_video = self.decoder(vae_batch_input) + + if ( + count == output_num_frames * output_height * output_width - 1 + and count % local_batch_size != local_batch_size - 1 + ): + decoded_videos[count - count % local_batch_size :] = current_video[ + : count % local_batch_size + 1 + ] + else: + decoded_videos[count - local_batch_size + 1 : count + 1] = current_video + + vae_batch_input = z.new_zeros( + (local_batch_size, num_channels, latent_kernel[0], latent_kernel[1], latent_kernel[2]) + ) + + count += 1 + + video = z.new_zeros((batch_size, self.config.out_channels, num_frames * rt, height * rs, width * rs)) + video_overlap = ( + self.kernel[0] - self.stride[0], + self.kernel[1] - self.stride[1], + self.kernel[2] - self.stride[2], + ) + + for i in range(output_num_frames): + n_start, n_end = i * self.stride[0], i * self.stride[0] + self.kernel[0] + for j in range(output_height): + h_start, h_end = j * self.stride[1], j * self.stride[1] + self.kernel[1] + for k in range(output_width): + w_start, w_end = k * self.stride[2], k * self.stride[2] + self.kernel[2] + out_video_blend = _prepare_for_blend( + (i, output_num_frames, video_overlap[0]), + (j, output_height, video_overlap[1]), + (k, output_width, video_overlap[2]), + decoded_videos[i * output_height * output_width + j * output_width + k].unsqueeze(0), + ) + video[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += out_video_blend + + video = video.permute(0, 2, 1, 3, 4).contiguous() + return video + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: torch.Generator | None = None, + ) -> DecoderOutput | torch.Tensor: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + generator (`torch.Generator`, *optional*): + PyTorch random number generator. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + +def _prepare_for_blend(n_param, h_param, w_param, x): + # TODO(aryan): refactor + n, n_max, overlap_n = n_param + h, h_max, overlap_h = h_param + w, w_max, overlap_w = w_param + if overlap_n > 0: + if n > 0: # the head overlap part decays from 0 to 1 + x[:, :, 0:overlap_n, :, :] = x[:, :, 0:overlap_n, :, :] * ( + torch.arange(0, overlap_n).float().to(x.device) / overlap_n + ).reshape(overlap_n, 1, 1) + if n < n_max - 1: # the tail overlap part decays from 1 to 0 + x[:, :, -overlap_n:, :, :] = x[:, :, -overlap_n:, :, :] * ( + 1 - torch.arange(0, overlap_n).float().to(x.device) / overlap_n + ).reshape(overlap_n, 1, 1) + if h > 0: + x[:, :, :, 0:overlap_h, :] = x[:, :, :, 0:overlap_h, :] * ( + torch.arange(0, overlap_h).float().to(x.device) / overlap_h + ).reshape(overlap_h, 1) + if h < h_max - 1: + x[:, :, :, -overlap_h:, :] = x[:, :, :, -overlap_h:, :] * ( + 1 - torch.arange(0, overlap_h).float().to(x.device) / overlap_h + ).reshape(overlap_h, 1) + if w > 0: + x[:, :, :, :, 0:overlap_w] = x[:, :, :, :, 0:overlap_w] * ( + torch.arange(0, overlap_w).float().to(x.device) / overlap_w + ) + if w < w_max - 1: + x[:, :, :, :, -overlap_w:] = x[:, :, :, :, -overlap_w:] * ( + 1 - torch.arange(0, overlap_w).float().to(x.device) / overlap_w + ) + return x diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py new file mode 100644 index 0000000000000000000000000000000000000000..9921e3932465982040720931c32de1efd35d7c11 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -0,0 +1,1421 @@ +# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders.single_file_model import FromOriginalModelMixin +from ...utils import logging +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..downsampling import CogVideoXDownsample3D +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from ..upsampling import CogVideoXUpsample3D +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class CogVideoXSafeConv3d(nn.Conv3d): + r""" + A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model. + """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + memory_count = ( + (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3 + ) + + # Set to 2GB, suitable for CuDNN + if memory_count > 2: + kernel_size = self.kernel_size[0] + part_num = int(memory_count / 2) + 1 + input_chunks = torch.chunk(input, part_num, dim=2) + + if kernel_size > 1: + input_chunks = [input_chunks[0]] + [ + torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2) + for i in range(1, len(input_chunks)) + ] + + output_chunks = [] + for input_chunk in input_chunks: + output_chunks.append(super().forward(input_chunk)) + output = torch.cat(output_chunks, dim=2) + return output + else: + return super().forward(input) + + +class CogVideoXCausalConv3d(nn.Module): + r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model. + + Args: + in_channels (`int`): Number of channels in the input tensor. + out_channels (`int`): Number of output channels produced by the convolution. + kernel_size (`int` or `tuple[int, int, int]`): Kernel size of the convolutional kernel. + stride (`int`, defaults to `1`): Stride of the convolution. + dilation (`int`, defaults to `1`): Dilation rate of the convolution. + pad_mode (`str`, defaults to `"constant"`): Padding mode. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int, int], + stride: int = 1, + dilation: int = 1, + pad_mode: str = "constant", + ): + super().__init__() + + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + # TODO(aryan): configure calculation based on stride and dilation in the future. + # Since CogVideoX does not use it, it is currently tailored to "just work" with Mochi + time_pad = time_kernel_size - 1 + height_pad = (height_kernel_size - 1) // 2 + width_pad = (width_kernel_size - 1) // 2 + + self.pad_mode = pad_mode + self.height_pad = height_pad + self.width_pad = width_pad + self.time_pad = time_pad + self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) + self.const_padding_conv3d = (0, self.width_pad, self.height_pad) + + self.temporal_dim = 2 + self.time_kernel_size = time_kernel_size + + stride = stride if isinstance(stride, tuple) else (stride, 1, 1) + dilation = (dilation, 1, 1) + self.conv = CogVideoXSafeConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=0 if self.pad_mode == "replicate" else self.const_padding_conv3d, + padding_mode="zeros", + ) + + def fake_context_parallel_forward( + self, inputs: torch.Tensor, conv_cache: torch.Tensor | None = None + ) -> torch.Tensor: + if self.pad_mode == "replicate": + inputs = F.pad(inputs, self.time_causal_padding, mode="replicate") + else: + kernel_size = self.time_kernel_size + if kernel_size > 1: + cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1) + inputs = torch.cat(cached_inputs + [inputs], dim=2) + return inputs + + def forward(self, inputs: torch.Tensor, conv_cache: torch.Tensor | None = None) -> torch.Tensor: + inputs = self.fake_context_parallel_forward(inputs, conv_cache) + + if self.pad_mode == "replicate": + conv_cache = None + else: + conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone() + + output = self.conv(inputs) + return output, conv_cache + + +class CogVideoXSpatialNorm3D(nn.Module): + r""" + Spatially conditioned normalization as defined in https://huggingface.co/papers/2209.09002. This implementation is + specific to 3D-video like data. + + CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model. + + Args: + f_channels (`int`): + The number of channels for input to group normalization layer, and output of the spatial norm layer. + zq_channels (`int`): + The number of channels for the quantized vector as described in the paper. + groups (`int`): + Number of groups to separate the channels into for group normalization. + """ + + def __init__( + self, + f_channels: int, + zq_channels: int, + groups: int = 32, + ): + super().__init__() + self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True) + self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) + self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) + + def forward( + self, f: torch.Tensor, zq: torch.Tensor, conv_cache: dict[str, torch.Tensor] | None = None + ) -> torch.Tensor: + new_conv_cache = {} + conv_cache = conv_cache or {} + + if f.shape[2] > 1 and f.shape[2] % 2 == 1: + f_first, f_rest = f[:, :, :1], f[:, :, 1:] + f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] + z_first, z_rest = zq[:, :, :1], zq[:, :, 1:] + z_first = F.interpolate(z_first, size=f_first_size) + z_rest = F.interpolate(z_rest, size=f_rest_size) + zq = torch.cat([z_first, z_rest], dim=2) + else: + zq = F.interpolate(zq, size=f.shape[-3:]) + + conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y")) + conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b")) + + norm_f = self.norm_layer(f) + new_f = norm_f * conv_y + conv_b + return new_f, new_conv_cache + + +class CogVideoXResnetBlock3D(nn.Module): + r""" + A 3D ResNet block used in the CogVideoX model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + dropout (`float`, defaults to `0.0`): + Dropout rate. + temb_channels (`int`, defaults to `512`): + Number of time embedding channels. + groups (`int`, defaults to `32`): + Number of groups to separate the channels into for group normalization. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + non_linearity (`str`, defaults to `"swish"`): + Activation function to use. + conv_shortcut (bool, defaults to `False`): + Whether or not to use a convolution shortcut. + spatial_norm_dim (`int`, *optional*): + The dimension to use for spatial norm if it is to be used instead of group norm. + pad_mode (str, defaults to `"first"`): + Padding mode. + """ + + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + eps: float = 1e-6, + non_linearity: str = "swish", + conv_shortcut: bool = False, + spatial_norm_dim: int | None = None, + pad_mode: str = "first", + ): + super().__init__() + + out_channels = out_channels or in_channels + + self.in_channels = in_channels + self.out_channels = out_channels + self.nonlinearity = get_activation(non_linearity) + self.use_conv_shortcut = conv_shortcut + self.spatial_norm_dim = spatial_norm_dim + + if spatial_norm_dim is None: + self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) + self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) + else: + self.norm1 = CogVideoXSpatialNorm3D( + f_channels=in_channels, + zq_channels=spatial_norm_dim, + groups=groups, + ) + self.norm2 = CogVideoXSpatialNorm3D( + f_channels=out_channels, + zq_channels=spatial_norm_dim, + groups=groups, + ) + + self.conv1 = CogVideoXCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode + ) + + if temb_channels > 0: + self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels) + + self.dropout = nn.Dropout(dropout) + self.conv2 = CogVideoXCausalConv3d( + in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode + ) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = CogVideoXCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode + ) + else: + self.conv_shortcut = CogVideoXSafeConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward( + self, + inputs: torch.Tensor, + temb: torch.Tensor | None = None, + zq: torch.Tensor | None = None, + conv_cache: dict[str, torch.Tensor] | None = None, + ) -> torch.Tensor: + new_conv_cache = {} + conv_cache = conv_cache or {} + + hidden_states = inputs + + if zq is not None: + hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1")) + else: + hidden_states = self.norm1(hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1")) + + if temb is not None: + hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if zq is not None: + hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2")) + else: + hidden_states = self.norm2(hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2")) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut( + inputs, conv_cache=conv_cache.get("conv_shortcut") + ) + else: + inputs = self.conv_shortcut(inputs) + + hidden_states = hidden_states + inputs + return hidden_states, new_conv_cache + + +class CogVideoXDownBlock3D(nn.Module): + r""" + A downsampling block used in the CogVideoX model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + temb_channels (`int`, defaults to `512`): + Number of time embedding channels. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + resnet_groups (`int`, defaults to `32`): + Number of groups to separate the channels into for group normalization. + add_downsample (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + compress_time (`bool`, defaults to `False`): + Whether or not to downsample across temporal dimension. + pad_mode (str, defaults to `"first"`): + Padding mode. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + add_downsample: bool = True, + downsample_padding: int = 0, + compress_time: bool = False, + pad_mode: str = "first", + ): + super().__init__() + + resnets = [] + for i in range(num_layers): + in_channel = in_channels if i == 0 else out_channels + resnets.append( + CogVideoXResnetBlock3D( + in_channels=in_channel, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=resnet_groups, + eps=resnet_eps, + non_linearity=resnet_act_fn, + pad_mode=pad_mode, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.downsamplers = None + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + CogVideoXDownsample3D( + out_channels, out_channels, padding=downsample_padding, compress_time=compress_time + ) + ] + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + zq: torch.Tensor | None = None, + conv_cache: dict[str, torch.Tensor] | None = None, + ) -> torch.Tensor: + r"""Forward method of the `CogVideoXDownBlock3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + for i, resnet in enumerate(self.resnets): + conv_cache_key = f"resnet_{i}" + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + resnet, + hidden_states, + temb, + zq, + conv_cache.get(conv_cache_key), + ) + else: + hidden_states, new_conv_cache[conv_cache_key] = resnet( + hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key) + ) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states, new_conv_cache + + +class CogVideoXMidBlock3D(nn.Module): + r""" + A middle block used in the CogVideoX model. + + Args: + in_channels (`int`): + Number of input channels. + temb_channels (`int`, defaults to `512`): + Number of time embedding channels. + dropout (`float`, defaults to `0.0`): + Dropout rate. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + resnet_groups (`int`, defaults to `32`): + Number of groups to separate the channels into for group normalization. + spatial_norm_dim (`int`, *optional*): + The dimension to use for spatial norm if it is to be used instead of group norm. + pad_mode (str, defaults to `"first"`): + Padding mode. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + spatial_norm_dim: int | None = None, + pad_mode: str = "first", + ): + super().__init__() + + resnets = [] + for _ in range(num_layers): + resnets.append( + CogVideoXResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=resnet_groups, + eps=resnet_eps, + spatial_norm_dim=spatial_norm_dim, + non_linearity=resnet_act_fn, + pad_mode=pad_mode, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + zq: torch.Tensor | None = None, + conv_cache: dict[str, torch.Tensor] | None = None, + ) -> torch.Tensor: + r"""Forward method of the `CogVideoXMidBlock3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + for i, resnet in enumerate(self.resnets): + conv_cache_key = f"resnet_{i}" + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + resnet, hidden_states, temb, zq, conv_cache.get(conv_cache_key) + ) + else: + hidden_states, new_conv_cache[conv_cache_key] = resnet( + hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key) + ) + + return hidden_states, new_conv_cache + + +class CogVideoXUpBlock3D(nn.Module): + r""" + An upsampling block used in the CogVideoX model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + temb_channels (`int`, defaults to `512`): + Number of time embedding channels. + dropout (`float`, defaults to `0.0`): + Dropout rate. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + resnet_groups (`int`, defaults to `32`): + Number of groups to separate the channels into for group normalization. + spatial_norm_dim (`int`, defaults to `16`): + The dimension to use for spatial norm if it is to be used instead of group norm. + add_upsample (`bool`, defaults to `True`): + Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension. + compress_time (`bool`, defaults to `False`): + Whether or not to downsample across temporal dimension. + pad_mode (str, defaults to `"first"`): + Padding mode. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + spatial_norm_dim: int = 16, + add_upsample: bool = True, + upsample_padding: int = 1, + compress_time: bool = False, + pad_mode: str = "first", + ): + super().__init__() + + resnets = [] + for i in range(num_layers): + in_channel = in_channels if i == 0 else out_channels + resnets.append( + CogVideoXResnetBlock3D( + in_channels=in_channel, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=resnet_groups, + eps=resnet_eps, + non_linearity=resnet_act_fn, + spatial_norm_dim=spatial_norm_dim, + pad_mode=pad_mode, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.upsamplers = None + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + CogVideoXUpsample3D( + out_channels, out_channels, padding=upsample_padding, compress_time=compress_time + ) + ] + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + zq: torch.Tensor | None = None, + conv_cache: dict[str, torch.Tensor] | None = None, + ) -> torch.Tensor: + r"""Forward method of the `CogVideoXUpBlock3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + for i, resnet in enumerate(self.resnets): + conv_cache_key = f"resnet_{i}" + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + resnet, + hidden_states, + temb, + zq, + conv_cache.get(conv_cache_key), + ) + else: + hidden_states, new_conv_cache[conv_cache_key] = resnet( + hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key) + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states, new_conv_cache + + +class CogVideoXEncoder3D(nn.Module): + r""" + The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + down_block_types (`tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available + options. + block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 16, + down_block_types: tuple[str, ...] = ( + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + ), + block_out_channels: tuple[int, ...] = (128, 256, 256, 512), + layers_per_block: int = 3, + act_fn: str = "silu", + norm_eps: float = 1e-6, + norm_num_groups: int = 32, + dropout: float = 0.0, + pad_mode: str = "first", + temporal_compression_ratio: float = 4, + ): + super().__init__() + + # log2 of temporal_compress_times + temporal_compress_level = int(np.log2(temporal_compression_ratio)) + + self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode) + self.down_blocks = nn.ModuleList([]) + + # down blocks + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + + if down_block_type == "CogVideoXDownBlock3D": + down_block = CogVideoXDownBlock3D( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=0, + dropout=dropout, + num_layers=layers_per_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + add_downsample=not is_final_block, + compress_time=compress_time, + ) + else: + raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`") + + self.down_blocks.append(down_block) + + # mid block + self.mid_block = CogVideoXMidBlock3D( + in_channels=block_out_channels[-1], + temb_channels=0, + dropout=dropout, + num_layers=2, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + pad_mode=pad_mode, + ) + + self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = CogVideoXCausalConv3d( + block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode + ) + + self.gradient_checkpointing = False + + def forward( + self, + sample: torch.Tensor, + temb: torch.Tensor | None = None, + conv_cache: dict[str, torch.Tensor] | None = None, + ) -> torch.Tensor: + r"""The forward method of the `CogVideoXEncoder3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in")) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + # 1. Down + for i, down_block in enumerate(self.down_blocks): + conv_cache_key = f"down_block_{i}" + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + down_block, + hidden_states, + temb, + None, + conv_cache.get(conv_cache_key), + ) + + # 2. Mid + hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func( + self.mid_block, + hidden_states, + temb, + None, + conv_cache.get("mid_block"), + ) + else: + # 1. Down + for i, down_block in enumerate(self.down_blocks): + conv_cache_key = f"down_block_{i}" + hidden_states, new_conv_cache[conv_cache_key] = down_block( + hidden_states, temb, None, conv_cache.get(conv_cache_key) + ) + + # 2. Mid + hidden_states, new_conv_cache["mid_block"] = self.mid_block( + hidden_states, temb, None, conv_cache=conv_cache.get("mid_block") + ) + + # 3. Post-process + hidden_states = self.norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + + hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out")) + + return hidden_states, new_conv_cache + + +class CogVideoXDecoder3D(nn.Module): + r""" + The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output + sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int = 16, + out_channels: int = 3, + up_block_types: tuple[str, ...] = ( + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + ), + block_out_channels: tuple[int, ...] = (128, 256, 256, 512), + layers_per_block: int = 3, + act_fn: str = "silu", + norm_eps: float = 1e-6, + norm_num_groups: int = 32, + dropout: float = 0.0, + pad_mode: str = "first", + temporal_compression_ratio: float = 4, + ): + super().__init__() + + reversed_block_out_channels = list(reversed(block_out_channels)) + + self.conv_in = CogVideoXCausalConv3d( + in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode + ) + + # mid block + self.mid_block = CogVideoXMidBlock3D( + in_channels=reversed_block_out_channels[0], + temb_channels=0, + num_layers=2, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + spatial_norm_dim=in_channels, + pad_mode=pad_mode, + ) + + # up blocks + self.up_blocks = nn.ModuleList([]) + + output_channel = reversed_block_out_channels[0] + temporal_compress_level = int(np.log2(temporal_compression_ratio)) + + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + + if up_block_type == "CogVideoXUpBlock3D": + up_block = CogVideoXUpBlock3D( + in_channels=prev_output_channel, + out_channels=output_channel, + temb_channels=0, + dropout=dropout, + num_layers=layers_per_block + 1, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + spatial_norm_dim=in_channels, + add_upsample=not is_final_block, + compress_time=compress_time, + pad_mode=pad_mode, + ) + prev_output_channel = output_channel + else: + raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`") + + self.up_blocks.append(up_block) + + self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups) + self.conv_act = nn.SiLU() + self.conv_out = CogVideoXCausalConv3d( + reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode + ) + + self.gradient_checkpointing = False + + def forward( + self, + sample: torch.Tensor, + temb: torch.Tensor | None = None, + conv_cache: dict[str, torch.Tensor] | None = None, + ) -> torch.Tensor: + r"""The forward method of the `CogVideoXDecoder3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in")) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + # 1. Mid + hidden_states, new_conv_cache["mid_block"] = self._gradient_checkpointing_func( + self.mid_block, + hidden_states, + temb, + sample, + conv_cache.get("mid_block"), + ) + + # 2. Up + for i, up_block in enumerate(self.up_blocks): + conv_cache_key = f"up_block_{i}" + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + up_block, + hidden_states, + temb, + sample, + conv_cache.get(conv_cache_key), + ) + else: + # 1. Mid + hidden_states, new_conv_cache["mid_block"] = self.mid_block( + hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block") + ) + + # 2. Up + for i, up_block in enumerate(self.up_blocks): + conv_cache_key = f"up_block_{i}" + hidden_states, new_conv_cache[conv_cache_key] = up_block( + hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key) + ) + + # 3. Post-process + hidden_states, new_conv_cache["norm_out"] = self.norm_out( + hidden_states, sample, conv_cache=conv_cache.get("norm_out") + ) + hidden_states = self.conv_act(hidden_states) + hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out")) + + return hidden_states, new_conv_cache + + +class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in + [CogVideoX](https://github.com/THUDM/CogVideo). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + tuple of downsample block types. + up_block_types (`tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + tuple of upsample block types. + block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`): + tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + scaling_factor (`float`, *optional*, defaults to `1.15258426`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast` + can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["CogVideoXResnetBlock3D"] + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: tuple[str] = ( + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + ), + up_block_types: tuple[str] = ( + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + ), + block_out_channels: tuple[int] = (128, 256, 256, 512), + latent_channels: int = 16, + layers_per_block: int = 3, + act_fn: str = "silu", + norm_eps: float = 1e-6, + norm_num_groups: int = 32, + temporal_compression_ratio: float = 4, + sample_height: int = 480, + sample_width: int = 720, + scaling_factor: float = 1.15258426, + shift_factor: float | None = None, + latents_mean: tuple[float] | None = None, + latents_std: tuple[float] | None = None, + force_upcast: float = True, + use_quant_conv: bool = False, + use_post_quant_conv: bool = False, + invert_scale_latents: bool = False, + ): + super().__init__() + + self.encoder = CogVideoXEncoder3D( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_eps=norm_eps, + norm_num_groups=norm_num_groups, + temporal_compression_ratio=temporal_compression_ratio, + ) + self.decoder = CogVideoXDecoder3D( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_eps=norm_eps, + norm_num_groups=norm_num_groups, + temporal_compression_ratio=temporal_compression_ratio, + ) + self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None + self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None + + self.use_slicing = False + self.use_tiling = False + + # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not + # recommended because the temporal parts of the VAE, here, are tricky to understand. + # If you decode X latent frames together, the number of output frames is: + # (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames + # + # Example with num_latent_frames_batch_size = 2: + # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together + # => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) + # => 6 * 8 = 48 frames + # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together + # => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) + + # ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) + # => 1 * 9 + 5 * 8 = 49 frames + # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that + # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different + # number of temporal frames. + self.num_latent_frames_batch_size = 2 + self.num_sample_frames_batch_size = 8 + + # We make the minimum height and width of sample for tiling half that of the generally supported + self.tile_sample_min_height = sample_height // 2 + self.tile_sample_min_width = sample_width // 2 + self.tile_latent_min_height = int( + self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1)) + ) + self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1))) + + # These are experimental overlap factors that were chosen based on experimentation and seem to work best for + # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX + # and so the tiling implementation has only been tested on those specific resolutions. + self.tile_overlap_factor_height = 1 / 6 + self.tile_overlap_factor_width = 1 / 5 + + def enable_tiling( + self, + tile_sample_min_height: int | None = None, + tile_sample_min_width: int | None = None, + tile_overlap_factor_height: float | None = None, + tile_overlap_factor_width: float | None = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_overlap_factor_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher + value might cause more tiles to be processed leading to slow down of the decoding process. + tile_overlap_factor_width (`int`, *optional*): + The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there + are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher + value might cause more tiles to be processed leading to slow down of the decoding process. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_latent_min_height = int( + self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1)) + ) + self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height + self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + frame_batch_size = self.num_sample_frames_batch_size + # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k. + # As the extra single frame is handled inside the loop, it is not required to round up here. + num_batches = max(num_frames // frame_batch_size, 1) + conv_cache = None + enc = [] + + for i in range(num_batches): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames) + end_frame = frame_batch_size * (i + 1) + remaining_frames + x_intermediate = x[:, :, start_frame:end_frame] + x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache) + if self.quant_conv is not None: + x_intermediate = self.quant_conv(x_intermediate) + enc.append(x_intermediate) + + enc = torch.cat(enc, dim=2) + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + batch_size, num_channels, num_frames, height, width = z.shape + + if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + frame_batch_size = self.num_latent_frames_batch_size + num_batches = max(num_frames // frame_batch_size, 1) + conv_cache = None + dec = [] + + for i in range(num_batches): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames) + end_frame = frame_batch_size * (i + 1) + remaining_frames + z_intermediate = z[:, :, start_frame:end_frame] + if self.post_quant_conv is not None: + z_intermediate = self.post_quant_conv(z_intermediate) + z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache) + dec.append(z_intermediate) + + dec = torch.cat(dec, dim=2) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + # For a rough memory estimate, take a look at the `tiled_decode` method. + batch_size, num_channels, num_frames, height, width = x.shape + + overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height)) + overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width)) + blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height) + blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width) + row_limit_height = self.tile_latent_min_height - blend_extent_height + row_limit_width = self.tile_latent_min_width - blend_extent_width + frame_batch_size = self.num_sample_frames_batch_size + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k. + # As the extra single frame is handled inside the loop, it is not required to round up here. + num_batches = max(num_frames // frame_batch_size, 1) + conv_cache = None + time = [] + + for k in range(num_batches): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) + end_frame = frame_batch_size * (k + 1) + remaining_frames + tile = x[ + :, + :, + start_frame:end_frame, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile, conv_cache = self.encoder(tile, conv_cache=conv_cache) + if self.quant_conv is not None: + tile = self.quant_conv(tile) + time.append(tile) + + row.append(torch.cat(time, dim=2)) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + enc = torch.cat(result_rows, dim=3) + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + # Rough memory assessment: + # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers. + # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720]. + # - Assume fp16 (2 bytes per value). + # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB + # + # Memory assessment when using tiling: + # - Assume everything as above but now HxW is 240x360 by tiling in half + # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB + + batch_size, num_channels, num_frames, height, width = z.shape + + overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height)) + overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width)) + blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height) + blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width) + row_limit_height = self.tile_sample_min_height - blend_extent_height + row_limit_width = self.tile_sample_min_width - blend_extent_width + frame_batch_size = self.num_latent_frames_batch_size + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + num_batches = max(num_frames // frame_batch_size, 1) + conv_cache = None + time = [] + + for k in range(num_batches): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) + end_frame = frame_batch_size * (k + 1) + remaining_frames + tile = z[ + :, + :, + start_frame:end_frame, + i : i + self.tile_latent_min_height, + j : j + self.tile_latent_min_width, + ] + if self.post_quant_conv is not None: + tile = self.post_quant_conv(tile) + tile, conv_cache = self.decoder(tile, conv_cache=conv_cache) + time.append(tile) + + row.append(torch.cat(time, dim=2)) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: torch.Generator | None = None, + ) -> torch.Tensor | torch.Tensor: + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_flux2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_flux2.py new file mode 100644 index 0000000000000000000000000000000000000000..c1345d5de73f503af32114b05db0c089a6c8349a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_flux2.py @@ -0,0 +1,487 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...loaders.single_file_model import FromOriginalModelMixin +from ...utils import deprecate +from ...utils.accelerate_utils import apply_forward_hook +from ..attention import AttentionMixin +from ..attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttnAddedKVProcessor, + AttnProcessor, + FusedAttnProcessor2_0, +) +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder + + +class AutoencoderKLFlux2( + ModelMixin, AutoencoderMixin, AttentionMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin +): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast` + can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + mid_block_add_attention (`bool`, *optional*, default to `True`): + If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the + mid_block will only have resnet blocks + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"] + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: tuple[str, ...] = ( + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ), + up_block_types: tuple[str, ...] = ( + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + ), + block_out_channels: tuple[int, ...] = ( + 128, + 256, + 512, + 512, + ), + layers_per_block: int = 2, + act_fn: str = "silu", + latent_channels: int = 32, + norm_num_groups: int = 32, + sample_size: int = 1024, # YiYi notes: not sure + force_upcast: bool = True, + use_quant_conv: bool = True, + use_post_quant_conv: bool = True, + mid_block_add_attention: bool = True, + batch_norm_eps: float = 1e-4, + batch_norm_momentum: float = 0.1, + patch_size: tuple[int, int] = (2, 2), + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + mid_block_add_attention=mid_block_add_attention, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + mid_block_add_attention=mid_block_add_attention, + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None + + self.bn = nn.BatchNorm2d( + math.prod(patch_size) * latent_channels, + eps=batch_norm_eps, + momentum=batch_norm_momentum, + affine=False, + track_running_stats=True, + ) + + self.use_slicing = False + self.use_tiling = False + + # only relevant if vae tiling is enabled + self.tile_sample_min_size = self.config.sample_size + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor = 0.25 + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size): + return self._tiled_encode(x) + + enc = self.encoder(x) + if self.quant_conv is not None: + enc = self.quant_conv(enc) + + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): + return self.tiled_decode(z, return_dict=return_dict) + + if self.post_quant_conv is not None: + z = self.post_quant_conv(z) + + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, z: torch.FloatTensor, return_dict: bool = True, generator=None + ) -> DecoderOutput | torch.FloatTensor: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) + return b + + def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + enc = torch.cat(result_rows, dim=2) + return enc + + def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain + `tuple` is returned. + """ + deprecation_message = ( + "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the " + "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able " + "to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value." + ) + deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False) + + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + moments = torch.cat(result_rows, dim=2) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[2], overlap_size): + row = [] + for j in range(0, z.shape[3], overlap_size): + tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] + if self.config.use_post_quant_conv: + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + dec = torch.cat(result_rows, dim=2) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: torch.Generator | None = None, + ) -> DecoderOutput | torch.Tensor: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + > [!WARNING] > This API is 🧪 experimental. + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + > [!WARNING] > This API is 🧪 experimental. + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage.py new file mode 100644 index 0000000000000000000000000000000000000000..6922ac8535541ea0198ab89d7db589f0a407140a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage.py @@ -0,0 +1,687 @@ +# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin +from ...utils import logging +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class HunyuanImageResnetBlock(nn.Module): + r""" + Residual block with two convolutions and optional channel change. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + non_linearity (str, optional): Type of non-linearity to use. Default is "silu". + """ + + def __init__(self, in_channels: int, out_channels: int, non_linearity: str = "silu") -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.nonlinearity = get_activation(non_linearity) + + # layers + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if in_channels != out_channels: + self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + else: + self.conv_shortcut = None + + def forward(self, x): + # Apply shortcut connection + residual = x + + # First normalization and activation + x = self.norm1(x) + x = self.nonlinearity(x) + + x = self.conv1(x) + x = self.norm2(x) + x = self.nonlinearity(x) + x = self.conv2(x) + + if self.conv_shortcut is not None: + x = self.conv_shortcut(x) + # Add residual connection + return x + residual + + +class HunyuanImageAttentionBlock(nn.Module): + r""" + Self-attention with a single head. + + Args: + in_channels (int): The number of channels in the input tensor. + """ + + def __init__(self, in_channels: int): + super().__init__() + + # layers + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.to_q = nn.Conv2d(in_channels, in_channels, 1) + self.to_k = nn.Conv2d(in_channels, in_channels, 1) + self.to_v = nn.Conv2d(in_channels, in_channels, 1) + self.proj = nn.Conv2d(in_channels, in_channels, 1) + + def forward(self, x): + identity = x + x = self.norm(x) + + # compute query, key, value + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + batch_size, channels, height, width = query.shape + query = query.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous() + key = key.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous() + value = value.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous() + + # apply attention + x = F.scaled_dot_product_attention(query, key, value) + + x = x.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + # output projection + x = self.proj(x) + + return x + identity + + +class HunyuanImageDownsample(nn.Module): + """ + Downsampling block for spatial reduction. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + """ + + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + factor = 4 + if out_channels % factor != 0: + raise ValueError(f"out_channels % factor != 0: {out_channels % factor}") + + self.conv = nn.Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1) + self.group_size = factor * in_channels // out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.conv(x) + + B, C, H, W = h.shape + h = h.reshape(B, C, H // 2, 2, W // 2, 2) + h = h.permute(0, 3, 5, 1, 2, 4) # b, r1, r2, c, h, w + h = h.reshape(B, 4 * C, H // 2, W // 2) + + B, C, H, W = x.shape + shortcut = x.reshape(B, C, H // 2, 2, W // 2, 2) + shortcut = shortcut.permute(0, 3, 5, 1, 2, 4) # b, r1, r2, c, h, w + shortcut = shortcut.reshape(B, 4 * C, H // 2, W // 2) + + B, C, H, W = shortcut.shape + shortcut = shortcut.view(B, h.shape[1], self.group_size, H, W).mean(dim=2) + return h + shortcut + + +class HunyuanImageUpsample(nn.Module): + """ + Upsampling block for spatial expansion. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + """ + + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + factor = 4 + self.conv = nn.Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1) + self.repeats = factor * out_channels // in_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.conv(x) + + B, C, H, W = h.shape + h = h.reshape(B, 2, 2, C // 4, H, W) # b, r1, r2, c, h, w + h = h.permute(0, 3, 4, 1, 5, 2) # b, c, h, r1, w, r2 + h = h.reshape(B, C // 4, H * 2, W * 2) + + shortcut = x.repeat_interleave(repeats=self.repeats, dim=1) + + B, C, H, W = shortcut.shape + shortcut = shortcut.reshape(B, 2, 2, C // 4, H, W) # b, r1, r2, c, h, w + shortcut = shortcut.permute(0, 3, 4, 1, 5, 2) # b, c, h, r1, w, r2 + shortcut = shortcut.reshape(B, C // 4, H * 2, W * 2) + return h + shortcut + + +class HunyuanImageMidBlock(nn.Module): + """ + Middle block for HunyuanImageVAE encoder and decoder. + + Args: + in_channels (int): Number of input channels. + num_layers (int): Number of layers. + """ + + def __init__(self, in_channels: int, num_layers: int = 1): + super().__init__() + + resnets = [HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels)] + + attentions = [] + for _ in range(num_layers): + attentions.append(HunyuanImageAttentionBlock(in_channels)) + resnets.append(HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels)) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.resnets[0](x) + + for attn, resnet in zip(self.attentions, self.resnets[1:]): + x = attn(x) + x = resnet(x) + + return x + + +class HunyuanImageEncoder2D(nn.Module): + r""" + Encoder network that compresses input to latent representation. + + Args: + in_channels (int): Number of input channels. + z_channels (int): Number of latent channels. + block_out_channels (list of int): Output channels for each block. + num_res_blocks (int): Number of residual blocks per block. + spatial_compression_ratio (int): Spatial downsampling factor. + non_linearity (str): Type of non-linearity to use. Default is "silu". + downsample_match_channel (bool): Whether to match channels during downsampling. + """ + + def __init__( + self, + in_channels: int, + z_channels: int, + block_out_channels: tuple[int, ...], + num_res_blocks: int, + spatial_compression_ratio: int, + non_linearity: str = "silu", + downsample_match_channel: bool = True, + ): + super().__init__() + if block_out_channels[-1] % (2 * z_channels) != 0: + raise ValueError( + f"block_out_channels[-1 has to be divisible by 2 * out_channels, you have block_out_channels = {block_out_channels[-1]} and out_channels = {z_channels}" + ) + + self.in_channels = in_channels + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + self.spatial_compression_ratio = spatial_compression_ratio + + self.group_size = block_out_channels[-1] // (2 * z_channels) + self.nonlinearity = get_activation(non_linearity) + + # init block + self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1) + + # downsample blocks + self.down_blocks = nn.ModuleList([]) + + block_in_channel = block_out_channels[0] + for i in range(len(block_out_channels)): + block_out_channel = block_out_channels[i] + # residual blocks + for _ in range(num_res_blocks): + self.down_blocks.append( + HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel) + ) + block_in_channel = block_out_channel + + # downsample block + if i < np.log2(spatial_compression_ratio) and i != len(block_out_channels) - 1: + if downsample_match_channel: + block_out_channel = block_out_channels[i + 1] + self.down_blocks.append( + HunyuanImageDownsample(in_channels=block_in_channel, out_channels=block_out_channel) + ) + block_in_channel = block_out_channel + + # middle blocks + self.mid_block = HunyuanImageMidBlock(in_channels=block_out_channels[-1], num_layers=1) + + # output blocks + # Output layers + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[-1], eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_out_channels[-1], 2 * z_channels, kernel_size=3, stride=1, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv_in(x) + + ## downsamples + for down_block in self.down_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + x = self._gradient_checkpointing_func(down_block, x) + else: + x = down_block(x) + + ## middle + if torch.is_grad_enabled() and self.gradient_checkpointing: + x = self._gradient_checkpointing_func(self.mid_block, x) + else: + x = self.mid_block(x) + + ## head + B, C, H, W = x.shape + residual = x.view(B, C // self.group_size, self.group_size, H, W).mean(dim=2) + + x = self.norm_out(x) + x = self.nonlinearity(x) + x = self.conv_out(x) + return x + residual + + +class HunyuanImageDecoder2D(nn.Module): + r""" + Decoder network that reconstructs output from latent representation. + + Args: + z_channels : int + Number of latent channels. + out_channels : int + Number of output channels. + block_out_channels : tuple[int, ...] + Output channels for each block. + num_res_blocks : int + Number of residual blocks per block. + spatial_compression_ratio : int + Spatial upsampling factor. + upsample_match_channel : bool + Whether to match channels during upsampling. + non_linearity (str): Type of non-linearity to use. Default is "silu". + """ + + def __init__( + self, + z_channels: int, + out_channels: int, + block_out_channels: tuple[int, ...], + num_res_blocks: int, + spatial_compression_ratio: int, + upsample_match_channel: bool = True, + non_linearity: str = "silu", + ): + super().__init__() + if block_out_channels[0] % z_channels != 0: + raise ValueError( + f"block_out_channels[0] should be divisible by z_channels but has block_out_channels[0] = {block_out_channels[0]} and z_channels = {z_channels}" + ) + + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + self.repeat = block_out_channels[0] // z_channels + self.spatial_compression_ratio = spatial_compression_ratio + self.nonlinearity = get_activation(non_linearity) + + self.conv_in = nn.Conv2d(z_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1) + + # Middle blocks with attention + self.mid_block = HunyuanImageMidBlock(in_channels=block_out_channels[0], num_layers=1) + + # Upsampling blocks + block_in_channel = block_out_channels[0] + self.up_blocks = nn.ModuleList() + for i in range(len(block_out_channels)): + block_out_channel = block_out_channels[i] + for _ in range(self.num_res_blocks + 1): + self.up_blocks.append( + HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel) + ) + block_in_channel = block_out_channel + + if i < np.log2(spatial_compression_ratio) and i != len(block_out_channels) - 1: + if upsample_match_channel: + block_out_channel = block_out_channels[i + 1] + self.up_blocks.append(HunyuanImageUpsample(block_in_channel, block_out_channel)) + block_in_channel = block_out_channel + + # Output layers + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[-1], eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, stride=1, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.conv_in(x) + x.repeat_interleave(repeats=self.repeat, dim=1) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + h = self._gradient_checkpointing_func(self.mid_block, h) + else: + h = self.mid_block(h) + + for up_block in self.up_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + h = self._gradient_checkpointing_func(up_block, h) + else: + h = up_block(h) + h = self.norm_out(h) + h = self.nonlinearity(h) + h = self.conv_out(h) + return h + + +class AutoencoderKLHunyuanImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model for 2D images with spatial tiling support. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + """ + + _supports_gradient_checkpointing = False + + # fmt: off + @register_to_config + def __init__( + self, + in_channels: int, + out_channels: int, + latent_channels: int, + block_out_channels: tuple[int, ...], + layers_per_block: int, + spatial_compression_ratio: int, + sample_size: int, + scaling_factor: float = None, + downsample_match_channel: bool = True, + upsample_match_channel: bool = True, + ) -> None: + # fmt: on + super().__init__() + + self.encoder = HunyuanImageEncoder2D( + in_channels=in_channels, + z_channels=latent_channels, + block_out_channels=block_out_channels, + num_res_blocks=layers_per_block, + spatial_compression_ratio=spatial_compression_ratio, + downsample_match_channel=downsample_match_channel, + ) + + self.decoder = HunyuanImageDecoder2D( + z_channels=latent_channels, + out_channels=out_channels, + block_out_channels=list(reversed(block_out_channels)), + num_res_blocks=layers_per_block, + spatial_compression_ratio=spatial_compression_ratio, + upsample_match_channel=upsample_match_channel, + ) + + # Tiling and slicing configuration + self.use_slicing = False + self.use_tiling = False + + # Tiling parameters + self.tile_sample_min_size = sample_size + self.tile_latent_min_size = sample_size // spatial_compression_ratio + self.tile_overlap_factor = 0.25 + + def enable_tiling( + self, + tile_sample_min_size: int | None = None, + tile_overlap_factor: float | None = None, + ) -> None: + r""" + Enable spatial tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles + to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to + allow processing larger images. + + Args: + tile_sample_min_size (`int`, *optional*): + The minimum size required for a sample to be separated into tiles across the spatial dimension. + tile_overlap_factor (`float`, *optional*): + The overlap factor required for a latent to be separated into tiles across the spatial dimension. + """ + self.use_tiling = True + self.tile_sample_min_size = tile_sample_min_size or self.tile_sample_min_size + self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor + self.tile_latent_min_size = self.tile_sample_min_size // self.config.spatial_compression_ratio + + def _encode(self, x: torch.Tensor): + + batch_size, num_channels, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size): + return self.tiled_encode(x) + + enc = self.encoder(x) + + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True): + + batch_size, num_channels, height, width = z.shape + + if self.use_tiling and (width > self.tile_latent_min_size or height > self.tile_latent_min_size): + return self.tiled_decode(z, return_dict=return_dict) + + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + """ + Encode input using spatial tiling strategy. + + Args: + x (`torch.Tensor`): Input tensor of shape (B, C, T, H, W). + + Returns: + `torch.Tensor`: + The latent representation of the encoded images. + """ + _, _, _, height, width = x.shape + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + rows = [] + for i in range(0, height, overlap_size): + row = [] + for j in range(0, width, overlap_size): + tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + row.append(tile) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + + moments = torch.cat(result_rows, dim=-2) + + return moments + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + """ + Decode latent using spatial tiling strategy. + + Args: + z (`torch.Tensor`): Latent tensor of shape (B, C, H, W). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + _, _, height, width = z.shape + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + rows = [] + for i in range(0, height, overlap_size): + row = [] + for j in range(0, width, overlap_size): + tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + + dec = torch.cat(result_rows, dim=-2) + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: torch.Generator | None = None, + ) -> DecoderOutput | torch.Tensor: + """ + Args: + sample (`torch.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + posterior = self.encode(sample).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + + return dec diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..81957e2feed46fa0cf3ff544f6b86638ae00dd2d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py @@ -0,0 +1,912 @@ +# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class HunyuanImageRefinerCausalConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int, int] = 3, + stride: int | tuple[int, int, int] = 1, + padding: int | tuple[int, int, int] = 0, + dilation: int | tuple[int, int, int] = 1, + bias: bool = True, + pad_mode: str = "replicate", + ) -> None: + super().__init__() + + kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + + self.pad_mode = pad_mode + self.time_causal_padding = ( + kernel_size[0] // 2, + kernel_size[0] // 2, + kernel_size[1] // 2, + kernel_size[1] // 2, + kernel_size[2] - 1, + 0, + ) + + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode) + return self.conv(hidden_states) + + +class HunyuanImageRefinerRMS_norm(nn.Module): + r""" + A custom RMS normalization layer. + + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. + """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class HunyuanImageRefinerAttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = HunyuanImageRefinerRMS_norm(in_channels, images=False) + + self.to_q = nn.Conv3d(in_channels, in_channels, kernel_size=1) + self.to_k = nn.Conv3d(in_channels, in_channels, kernel_size=1) + self.to_v = nn.Conv3d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv3d(in_channels, in_channels, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + identity = x + + x = self.norm(x) + + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + batch_size, channels, frames, height, width = query.shape + + query = query.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous() + key = key.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous() + value = value.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous() + + x = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None) + + # batch_size, 1, frames * height * width, channels + + x = x.squeeze(1).reshape(batch_size, frames, height, width, channels).permute(0, 4, 1, 2, 3) + x = self.proj_out(x) + + return x + identity + + +class HunyuanImageRefinerUpsampleDCAE(nn.Module): + def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True): + super().__init__() + factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2 + self.conv = HunyuanImageRefinerCausalConv3d(in_channels, out_channels * factor, kernel_size=3) + + self.add_temporal_upsample = add_temporal_upsample + self.repeats = factor * out_channels // in_channels + + @staticmethod + def _dcae_upsample_rearrange(tensor, r1=1, r2=2, r3=2): + """ + Convert (b, r1*r2*r3*c, f, h, w) -> (b, c, r1*f, r2*h, r3*w) + + Args: + tensor: Input tensor of shape (b, r1*r2*r3*c, f, h, w) + r1: temporal upsampling factor + r2: height upsampling factor + r3: width upsampling factor + """ + b, packed_c, f, h, w = tensor.shape + factor = r1 * r2 * r3 + c = packed_c // factor + + tensor = tensor.view(b, r1, r2, r3, c, f, h, w) + tensor = tensor.permute(0, 4, 5, 1, 6, 2, 7, 3) + return tensor.reshape(b, c, f * r1, h * r2, w * r3) + + def forward(self, x: torch.Tensor): + r1 = 2 if self.add_temporal_upsample else 1 + h = self.conv(x) + if self.add_temporal_upsample: + h = self._dcae_upsample_rearrange(h, r1=1, r2=2, r3=2) + h = h[:, : h.shape[1] // 2] + + # shortcut computation + shortcut = self._dcae_upsample_rearrange(x, r1=1, r2=2, r3=2) + shortcut = shortcut.repeat_interleave(repeats=self.repeats // 2, dim=1) + + else: + h = self._dcae_upsample_rearrange(h, r1=r1, r2=2, r3=2) + shortcut = x.repeat_interleave(repeats=self.repeats, dim=1) + shortcut = self._dcae_upsample_rearrange(shortcut, r1=r1, r2=2, r3=2) + return h + shortcut + + +class HunyuanImageRefinerDownsampleDCAE(nn.Module): + def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True): + super().__init__() + factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2 + assert out_channels % factor == 0 + # self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1) + self.conv = HunyuanImageRefinerCausalConv3d(in_channels, out_channels // factor, kernel_size=3) + + self.add_temporal_downsample = add_temporal_downsample + self.group_size = factor * in_channels // out_channels + + @staticmethod + def _dcae_downsample_rearrange(tensor, r1=1, r2=2, r3=2): + """ + Convert (b, c, r1*f, r2*h, r3*w) -> (b, r1*r2*r3*c, f, h, w) + + This packs spatial/temporal dimensions into channels (opposite of upsample) + """ + b, c, packed_f, packed_h, packed_w = tensor.shape + f, h, w = packed_f // r1, packed_h // r2, packed_w // r3 + + tensor = tensor.view(b, c, f, r1, h, r2, w, r3) + tensor = tensor.permute(0, 3, 5, 7, 1, 2, 4, 6) + return tensor.reshape(b, r1 * r2 * r3 * c, f, h, w) + + def forward(self, x: torch.Tensor): + r1 = 2 if self.add_temporal_downsample else 1 + h = self.conv(x) + if self.add_temporal_downsample: + # h = rearrange(h, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2) + h = self._dcae_downsample_rearrange(h, r1=1, r2=2, r3=2) + h = torch.cat([h, h], dim=1) + # shortcut computation + # shortcut = rearrange(x, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2) + shortcut = self._dcae_downsample_rearrange(x, r1=1, r2=2, r3=2) + B, C, T, H, W = shortcut.shape + shortcut = shortcut.view(B, h.shape[1], self.group_size // 2, T, H, W).mean(dim=2) + else: + # h = rearrange(h, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2) + h = self._dcae_downsample_rearrange(h, r1=r1, r2=2, r3=2) + # shortcut = rearrange(x, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2) + shortcut = self._dcae_downsample_rearrange(x, r1=r1, r2=2, r3=2) + B, C, T, H, W = shortcut.shape + shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2) + + return h + shortcut + + +class HunyuanImageRefinerResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + non_linearity: str = "swish", + ) -> None: + super().__init__() + out_channels = out_channels or in_channels + + self.nonlinearity = get_activation(non_linearity) + + self.norm1 = HunyuanImageRefinerRMS_norm(in_channels, images=False) + self.conv1 = HunyuanImageRefinerCausalConv3d(in_channels, out_channels, kernel_size=3) + + self.norm2 = HunyuanImageRefinerRMS_norm(out_channels, images=False) + self.conv2 = HunyuanImageRefinerCausalConv3d(out_channels, out_channels, kernel_size=3) + + self.conv_shortcut = None + if in_channels != out_channels: + self.conv_shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + return hidden_states + residual + + +class HunyuanImageRefinerMidBlock(nn.Module): + def __init__( + self, + in_channels: int, + num_layers: int = 1, + add_attention: bool = True, + ) -> None: + super().__init__() + self.add_attention = add_attention + + # There is always at least one resnet + resnets = [ + HunyuanImageRefinerResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + ) + ] + attentions = [] + + for _ in range(num_layers): + if self.add_attention: + attentions.append(HunyuanImageRefinerAttnBlock(in_channels)) + else: + attentions.append(None) + + resnets.append( + HunyuanImageRefinerResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.resnets[0](hidden_states) + + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = attn(hidden_states) + hidden_states = resnet(hidden_states) + + return hidden_states + + +class HunyuanImageRefinerDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + downsample_out_channels: int | None = None, + add_temporal_downsample: int = True, + ) -> None: + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + HunyuanImageRefinerResnetBlock( + in_channels=in_channels, + out_channels=out_channels, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if downsample_out_channels is not None: + self.downsamplers = nn.ModuleList( + [ + HunyuanImageRefinerDownsampleDCAE( + out_channels, + out_channels=downsample_out_channels, + add_temporal_downsample=add_temporal_downsample, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class HunyuanImageRefinerUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + upsample_out_channels: int | None = None, + add_temporal_upsample: bool = True, + ) -> None: + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + HunyuanImageRefinerResnetBlock( + in_channels=input_channels, + out_channels=out_channels, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if upsample_out_channels is not None: + self.upsamplers = nn.ModuleList( + [ + HunyuanImageRefinerUpsampleDCAE( + out_channels, + out_channels=upsample_out_channels, + add_temporal_upsample=add_temporal_upsample, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if torch.is_grad_enabled() and self.gradient_checkpointing: + for resnet in self.resnets: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states) + + else: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class HunyuanImageRefinerEncoder3D(nn.Module): + r""" + 3D vae encoder for HunyuanImageRefiner. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 64, + block_out_channels: tuple[int, ...] = (128, 256, 512, 1024, 1024), + layers_per_block: int = 2, + temporal_compression_ratio: int = 4, + spatial_compression_ratio: int = 16, + downsample_match_channel: bool = True, + ) -> None: + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.group_size = block_out_channels[-1] // self.out_channels + + self.conv_in = HunyuanImageRefinerCausalConv3d(in_channels, block_out_channels[0], kernel_size=3) + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + + input_channel = block_out_channels[0] + for i in range(len(block_out_channels)): + add_spatial_downsample = i < np.log2(spatial_compression_ratio) + output_channel = block_out_channels[i] + if not add_spatial_downsample: + down_block = HunyuanImageRefinerDownBlock3D( + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + downsample_out_channels=None, + add_temporal_downsample=False, + ) + input_channel = output_channel + else: + add_temporal_downsample = i >= np.log2(spatial_compression_ratio // temporal_compression_ratio) + downsample_out_channels = block_out_channels[i + 1] if downsample_match_channel else output_channel + down_block = HunyuanImageRefinerDownBlock3D( + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + downsample_out_channels=downsample_out_channels, + add_temporal_downsample=add_temporal_downsample, + ) + input_channel = downsample_out_channels + + self.down_blocks.append(down_block) + + self.mid_block = HunyuanImageRefinerMidBlock(in_channels=block_out_channels[-1]) + + self.norm_out = HunyuanImageRefinerRMS_norm(block_out_channels[-1], images=False) + self.conv_act = nn.SiLU() + self.conv_out = HunyuanImageRefinerCausalConv3d(block_out_channels[-1], out_channels, kernel_size=3) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for down_block in self.down_blocks: + hidden_states = self._gradient_checkpointing_func(down_block, hidden_states) + + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states) + else: + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states) + + hidden_states = self.mid_block(hidden_states) + + # short_cut = rearrange(hidden_states, "b (c r) f h w -> b c r f h w", r=self.group_size).mean(dim=2) + batch_size, _, frame, height, width = hidden_states.shape + short_cut = hidden_states.view(batch_size, -1, self.group_size, frame, height, width).mean(dim=2) + + hidden_states = self.norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + hidden_states += short_cut + + return hidden_states + + +class HunyuanImageRefinerDecoder3D(nn.Module): + r""" + Causal decoder for 3D video-like data used for HunyuanImage-2.1 Refiner. + """ + + def __init__( + self, + in_channels: int = 32, + out_channels: int = 3, + block_out_channels: tuple[int, ...] = (1024, 1024, 512, 256, 128), + layers_per_block: int = 2, + spatial_compression_ratio: int = 16, + temporal_compression_ratio: int = 4, + upsample_match_channel: bool = True, + ): + super().__init__() + self.layers_per_block = layers_per_block + self.in_channels = in_channels + self.out_channels = out_channels + self.repeat = block_out_channels[0] // self.in_channels + + self.conv_in = HunyuanImageRefinerCausalConv3d(self.in_channels, block_out_channels[0], kernel_size=3) + self.up_blocks = nn.ModuleList([]) + + # mid + self.mid_block = HunyuanImageRefinerMidBlock(in_channels=block_out_channels[0]) + + # up + input_channel = block_out_channels[0] + for i in range(len(block_out_channels)): + output_channel = block_out_channels[i] + + add_spatial_upsample = i < np.log2(spatial_compression_ratio) + add_temporal_upsample = i < np.log2(temporal_compression_ratio) + if add_spatial_upsample or add_temporal_upsample: + upsample_out_channels = block_out_channels[i + 1] if upsample_match_channel else output_channel + up_block = HunyuanImageRefinerUpBlock3D( + num_layers=self.layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + upsample_out_channels=upsample_out_channels, + add_temporal_upsample=add_temporal_upsample, + ) + input_channel = upsample_out_channels + else: + up_block = HunyuanImageRefinerUpBlock3D( + num_layers=self.layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + upsample_out_channels=None, + add_temporal_upsample=False, + ) + input_channel = output_channel + + self.up_blocks.append(up_block) + + # out + self.norm_out = HunyuanImageRefinerRMS_norm(block_out_channels[-1], images=False) + self.conv_act = nn.SiLU() + self.conv_out = HunyuanImageRefinerCausalConv3d(block_out_channels[-1], out_channels, kernel_size=3) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + hidden_states.repeat_interleave(repeats=self.repeat, dim=1) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states) + + for up_block in self.up_blocks: + hidden_states = self._gradient_checkpointing_func(up_block, hidden_states) + else: + hidden_states = self.mid_block(hidden_states) + + for up_block in self.up_blocks: + hidden_states = up_block(hidden_states) + + # post-process + hidden_states = self.norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + return hidden_states + + +class AutoencoderKLHunyuanImageRefiner(ModelMixin, AutoencoderMixin, ConfigMixin): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for + HunyuanImage-2.1 Refiner. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + latent_channels: int = 32, + block_out_channels: tuple[int, ...] = (128, 256, 512, 1024, 1024), + layers_per_block: int = 2, + spatial_compression_ratio: int = 16, + temporal_compression_ratio: int = 4, + downsample_match_channel: bool = True, + upsample_match_channel: bool = True, + scaling_factor: float = 1.03682, + ) -> None: + super().__init__() + + self.encoder = HunyuanImageRefinerEncoder3D( + in_channels=in_channels, + out_channels=latent_channels * 2, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + temporal_compression_ratio=temporal_compression_ratio, + spatial_compression_ratio=spatial_compression_ratio, + downsample_match_channel=downsample_match_channel, + ) + + self.decoder = HunyuanImageRefinerDecoder3D( + in_channels=latent_channels, + out_channels=out_channels, + block_out_channels=list(reversed(block_out_channels)), + layers_per_block=layers_per_block, + temporal_compression_ratio=temporal_compression_ratio, + spatial_compression_ratio=spatial_compression_ratio, + upsample_match_channel=upsample_match_channel, + ) + + self.spatial_compression_ratio = spatial_compression_ratio + self.temporal_compression_ratio = temporal_compression_ratio + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + + self.tile_overlap_factor = 0.25 + + def enable_tiling( + self, + tile_sample_min_height: int | None = None, + tile_sample_min_width: int | None = None, + tile_sample_stride_height: float | None = None, + tile_sample_stride_width: float | None = None, + tile_overlap_factor: float | None = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + _, _, _, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + x = self.encoder(x) + return x + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor) -> torch.Tensor: + _, _, _, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z) + + dec = self.decoder(z) + + return dec + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z) + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) + for x in range(blend_extent): + b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + _, _, _, height, width = x.shape + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + overlap_height = int(tile_latent_min_height * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192 + overlap_width = int(tile_latent_min_width * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192 + blend_height = int(tile_latent_min_height * self.tile_overlap_factor) # 8 * 0.25 = 2 + blend_width = int(tile_latent_min_width * self.tile_overlap_factor) # 8 * 0.25 = 2 + row_limit_height = tile_latent_min_height - blend_height # 8 - 2 = 6 + row_limit_width = tile_latent_min_width - blend_width # 8 - 2 = 6 + + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + tile = x[ + :, + :, + :, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile = self.encoder(tile) + row.append(tile) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + moments = torch.cat(result_rows, dim=-2) + + return moments + + def tiled_decode(self, z: torch.Tensor) -> torch.Tensor: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + + _, _, _, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + overlap_height = int(tile_latent_min_height * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6 + overlap_width = int(tile_latent_min_width * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6 + blend_height = int(tile_latent_min_height * self.tile_overlap_factor) # 256 * 0.25 = 64 + blend_width = int(tile_latent_min_width * self.tile_overlap_factor) # 256 * 0.25 = 64 + row_limit_height = tile_latent_min_height - blend_height # 256 - 64 = 192 + row_limit_width = tile_latent_min_width - blend_width # 256 - 64 = 192 + + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + tile = z[ + :, + :, + :, + i : i + tile_latent_min_height, + j : j + tile_latent_min_width, + ] + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + dec = torch.cat(result_rows, dim=-2) + + return dec + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: torch.Generator | None = None, + ) -> DecoderOutput | torch.Tensor: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py new file mode 100644 index 0000000000000000000000000000000000000000..2c38b174a1004a892e41d658641d5848383707f7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py @@ -0,0 +1,945 @@ +# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class HunyuanVideo15CausalConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int, int] = 3, + stride: int | tuple[int, int, int] = 1, + padding: int | tuple[int, int, int] = 0, + dilation: int | tuple[int, int, int] = 1, + bias: bool = True, + pad_mode: str = "replicate", + ) -> None: + super().__init__() + + kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + + self.pad_mode = pad_mode + self.time_causal_padding = ( + kernel_size[0] // 2, + kernel_size[0] // 2, + kernel_size[1] // 2, + kernel_size[1] // 2, + kernel_size[2] - 1, + 0, + ) + + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode) + return self.conv(hidden_states) + + +class HunyuanVideo15RMS_norm(nn.Module): + r""" + A custom RMS normalization layer. + + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. + """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class HunyuanVideo15AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = HunyuanVideo15RMS_norm(in_channels, images=False) + + self.to_q = nn.Conv3d(in_channels, in_channels, kernel_size=1) + self.to_k = nn.Conv3d(in_channels, in_channels, kernel_size=1) + self.to_v = nn.Conv3d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv3d(in_channels, in_channels, kernel_size=1) + + @staticmethod + def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None): + """Prepare a causal attention mask for 3D videos. + + Args: + n_frame (int): Number of frames (temporal length). + n_hw (int): Product of height and width. + dtype: Desired mask dtype. + device: Device for the mask. + batch_size (int, optional): If set, expands for batch. + + Returns: + torch.Tensor: Causal attention mask. + """ + seq_len = n_frame * n_hw + mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device) + for i in range(seq_len): + i_frame = i // n_hw + mask[i, : (i_frame + 1) * n_hw] = 0 + if batch_size is not None: + mask = mask.unsqueeze(0).expand(batch_size, -1, -1) + return mask + + def forward(self, x: torch.Tensor) -> torch.Tensor: + identity = x + + x = self.norm(x) + + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + batch_size, channels, frames, height, width = query.shape + + query = query.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous() + key = key.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous() + value = value.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous() + + attention_mask = self.prepare_causal_attention_mask( + frames, height * width, query.dtype, query.device, batch_size=batch_size + ) + + x = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) + + # batch_size, 1, frames * height * width, channels + + x = x.squeeze(1).reshape(batch_size, frames, height, width, channels).permute(0, 4, 1, 2, 3) + x = self.proj_out(x) + + return x + identity + + +class HunyuanVideo15Upsample(nn.Module): + def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True): + super().__init__() + factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2 + self.conv = HunyuanVideo15CausalConv3d(in_channels, out_channels * factor, kernel_size=3) + + self.add_temporal_upsample = add_temporal_upsample + self.repeats = factor * out_channels // in_channels + + @staticmethod + def _dcae_upsample_rearrange(tensor, r1=1, r2=2, r3=2): + """ + Convert (b, r1*r2*r3*c, f, h, w) -> (b, c, r1*f, r2*h, r3*w) + + Args: + tensor: Input tensor of shape (b, r1*r2*r3*c, f, h, w) + r1: temporal upsampling factor + r2: height upsampling factor + r3: width upsampling factor + """ + b, packed_c, f, h, w = tensor.shape + factor = r1 * r2 * r3 + c = packed_c // factor + + tensor = tensor.view(b, r1, r2, r3, c, f, h, w) + tensor = tensor.permute(0, 4, 5, 1, 6, 2, 7, 3) + return tensor.reshape(b, c, f * r1, h * r2, w * r3) + + def forward(self, x: torch.Tensor): + r1 = 2 if self.add_temporal_upsample else 1 + h = self.conv(x) + if self.add_temporal_upsample: + h_first = h[:, :, :1, :, :] + h_first = self._dcae_upsample_rearrange(h_first, r1=1, r2=2, r3=2) + h_first = h_first[:, : h_first.shape[1] // 2] + h_next = h[:, :, 1:, :, :] + h_next = self._dcae_upsample_rearrange(h_next, r1=r1, r2=2, r3=2) + h = torch.cat([h_first, h_next], dim=2) + + # shortcut computation + x_first = x[:, :, :1, :, :] + x_first = self._dcae_upsample_rearrange(x_first, r1=1, r2=2, r3=2) + x_first = x_first.repeat_interleave(repeats=self.repeats // 2, dim=1) + + x_next = x[:, :, 1:, :, :] + x_next = self._dcae_upsample_rearrange(x_next, r1=r1, r2=2, r3=2) + x_next = x_next.repeat_interleave(repeats=self.repeats, dim=1) + shortcut = torch.cat([x_first, x_next], dim=2) + + else: + h = self._dcae_upsample_rearrange(h, r1=r1, r2=2, r3=2) + shortcut = x.repeat_interleave(repeats=self.repeats, dim=1) + shortcut = self._dcae_upsample_rearrange(shortcut, r1=r1, r2=2, r3=2) + return h + shortcut + + +class HunyuanVideo15Downsample(nn.Module): + def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True): + super().__init__() + factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2 + self.conv = HunyuanVideo15CausalConv3d(in_channels, out_channels // factor, kernel_size=3) + + self.add_temporal_downsample = add_temporal_downsample + self.group_size = factor * in_channels // out_channels + + @staticmethod + def _dcae_downsample_rearrange(tensor, r1=1, r2=2, r3=2): + """ + Convert (b, c, r1*f, r2*h, r3*w) -> (b, r1*r2*r3*c, f, h, w) + + This packs spatial/temporal dimensions into channels (opposite of upsample) + """ + b, c, packed_f, packed_h, packed_w = tensor.shape + f, h, w = packed_f // r1, packed_h // r2, packed_w // r3 + + tensor = tensor.view(b, c, f, r1, h, r2, w, r3) + tensor = tensor.permute(0, 3, 5, 7, 1, 2, 4, 6) + return tensor.reshape(b, r1 * r2 * r3 * c, f, h, w) + + def forward(self, x: torch.Tensor): + r1 = 2 if self.add_temporal_downsample else 1 + h = self.conv(x) + if self.add_temporal_downsample: + h_first = h[:, :, :1, :, :] + h_first = self._dcae_downsample_rearrange(h_first, r1=1, r2=2, r3=2) + h_first = torch.cat([h_first, h_first], dim=1) + h_next = h[:, :, 1:, :, :] + h_next = self._dcae_downsample_rearrange(h_next, r1=r1, r2=2, r3=2) + h = torch.cat([h_first, h_next], dim=2) + + # shortcut computation + x_first = x[:, :, :1, :, :] + x_first = self._dcae_downsample_rearrange(x_first, r1=1, r2=2, r3=2) + B, C, T, H, W = x_first.shape + x_first = x_first.view(B, h.shape[1], self.group_size // 2, T, H, W).mean(dim=2) + x_next = x[:, :, 1:, :, :] + x_next = self._dcae_downsample_rearrange(x_next, r1=r1, r2=2, r3=2) + B, C, T, H, W = x_next.shape + x_next = x_next.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2) + shortcut = torch.cat([x_first, x_next], dim=2) + else: + h = self._dcae_downsample_rearrange(h, r1=r1, r2=2, r3=2) + shortcut = self._dcae_downsample_rearrange(x, r1=r1, r2=2, r3=2) + B, C, T, H, W = shortcut.shape + shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2) + + return h + shortcut + + +class HunyuanVideo15ResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + non_linearity: str = "swish", + ) -> None: + super().__init__() + out_channels = out_channels or in_channels + + self.nonlinearity = get_activation(non_linearity) + + self.norm1 = HunyuanVideo15RMS_norm(in_channels, images=False) + self.conv1 = HunyuanVideo15CausalConv3d(in_channels, out_channels, kernel_size=3) + + self.norm2 = HunyuanVideo15RMS_norm(out_channels, images=False) + self.conv2 = HunyuanVideo15CausalConv3d(out_channels, out_channels, kernel_size=3) + + self.conv_shortcut = None + if in_channels != out_channels: + self.conv_shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + return hidden_states + residual + + +class HunyuanVideo15MidBlock(nn.Module): + def __init__( + self, + in_channels: int, + num_layers: int = 1, + add_attention: bool = True, + ) -> None: + super().__init__() + self.add_attention = add_attention + + # There is always at least one resnet + resnets = [ + HunyuanVideo15ResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + ) + ] + attentions = [] + + for _ in range(num_layers): + if self.add_attention: + attentions.append(HunyuanVideo15AttnBlock(in_channels)) + else: + attentions.append(None) + + resnets.append( + HunyuanVideo15ResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.resnets[0](hidden_states) + + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = attn(hidden_states) + hidden_states = resnet(hidden_states) + + return hidden_states + + +class HunyuanVideo15DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + downsample_out_channels: int | None = None, + add_temporal_downsample: int = True, + ) -> None: + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + HunyuanVideo15ResnetBlock( + in_channels=in_channels, + out_channels=out_channels, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if downsample_out_channels is not None: + self.downsamplers = nn.ModuleList( + [ + HunyuanVideo15Downsample( + out_channels, + out_channels=downsample_out_channels, + add_temporal_downsample=add_temporal_downsample, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class HunyuanVideo15UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + upsample_out_channels: int | None = None, + add_temporal_upsample: bool = True, + ) -> None: + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + HunyuanVideo15ResnetBlock( + in_channels=input_channels, + out_channels=out_channels, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if upsample_out_channels is not None: + self.upsamplers = nn.ModuleList( + [ + HunyuanVideo15Upsample( + out_channels, + out_channels=upsample_out_channels, + add_temporal_upsample=add_temporal_upsample, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if torch.is_grad_enabled() and self.gradient_checkpointing: + for resnet in self.resnets: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states) + + else: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class HunyuanVideo15Encoder3D(nn.Module): + r""" + 3D vae encoder for HunyuanImageRefiner. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 64, + block_out_channels: tuple[int, ...] = (128, 256, 512, 1024, 1024), + layers_per_block: int = 2, + temporal_compression_ratio: int = 4, + spatial_compression_ratio: int = 16, + downsample_match_channel: bool = True, + ) -> None: + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.group_size = block_out_channels[-1] // self.out_channels + + self.conv_in = HunyuanVideo15CausalConv3d(in_channels, block_out_channels[0], kernel_size=3) + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + + input_channel = block_out_channels[0] + for i in range(len(block_out_channels)): + add_spatial_downsample = i < np.log2(spatial_compression_ratio) + output_channel = block_out_channels[i] + if not add_spatial_downsample: + down_block = HunyuanVideo15DownBlock3D( + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + downsample_out_channels=None, + add_temporal_downsample=False, + ) + input_channel = output_channel + else: + add_temporal_downsample = i >= np.log2(spatial_compression_ratio // temporal_compression_ratio) + downsample_out_channels = block_out_channels[i + 1] if downsample_match_channel else output_channel + down_block = HunyuanVideo15DownBlock3D( + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + downsample_out_channels=downsample_out_channels, + add_temporal_downsample=add_temporal_downsample, + ) + input_channel = downsample_out_channels + + self.down_blocks.append(down_block) + + self.mid_block = HunyuanVideo15MidBlock(in_channels=block_out_channels[-1]) + + self.norm_out = HunyuanVideo15RMS_norm(block_out_channels[-1], images=False) + self.conv_act = nn.SiLU() + self.conv_out = HunyuanVideo15CausalConv3d(block_out_channels[-1], out_channels, kernel_size=3) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for down_block in self.down_blocks: + hidden_states = self._gradient_checkpointing_func(down_block, hidden_states) + + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states) + else: + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states) + + hidden_states = self.mid_block(hidden_states) + + batch_size, _, frame, height, width = hidden_states.shape + short_cut = hidden_states.view(batch_size, -1, self.group_size, frame, height, width).mean(dim=2) + + hidden_states = self.norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + hidden_states += short_cut + + return hidden_states + + +class HunyuanVideo15Decoder3D(nn.Module): + r""" + Causal decoder for 3D video-like data used for HunyuanImage-1.5 Refiner. + """ + + def __init__( + self, + in_channels: int = 32, + out_channels: int = 3, + block_out_channels: tuple[int, ...] = (1024, 1024, 512, 256, 128), + layers_per_block: int = 2, + spatial_compression_ratio: int = 16, + temporal_compression_ratio: int = 4, + upsample_match_channel: bool = True, + ): + super().__init__() + self.layers_per_block = layers_per_block + self.in_channels = in_channels + self.out_channels = out_channels + self.repeat = block_out_channels[0] // self.in_channels + + self.conv_in = HunyuanVideo15CausalConv3d(self.in_channels, block_out_channels[0], kernel_size=3) + self.up_blocks = nn.ModuleList([]) + + # mid + self.mid_block = HunyuanVideo15MidBlock(in_channels=block_out_channels[0]) + + # up + input_channel = block_out_channels[0] + for i in range(len(block_out_channels)): + output_channel = block_out_channels[i] + + add_spatial_upsample = i < np.log2(spatial_compression_ratio) + add_temporal_upsample = i < np.log2(temporal_compression_ratio) + if add_spatial_upsample or add_temporal_upsample: + upsample_out_channels = block_out_channels[i + 1] if upsample_match_channel else output_channel + up_block = HunyuanVideo15UpBlock3D( + num_layers=self.layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + upsample_out_channels=upsample_out_channels, + add_temporal_upsample=add_temporal_upsample, + ) + input_channel = upsample_out_channels + else: + up_block = HunyuanVideo15UpBlock3D( + num_layers=self.layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + upsample_out_channels=None, + add_temporal_upsample=False, + ) + input_channel = output_channel + + self.up_blocks.append(up_block) + + # out + self.norm_out = HunyuanVideo15RMS_norm(block_out_channels[-1], images=False) + self.conv_act = nn.SiLU() + self.conv_out = HunyuanVideo15CausalConv3d(block_out_channels[-1], out_channels, kernel_size=3) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + hidden_states.repeat_interleave(repeats=self.repeat, dim=1) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states) + + for up_block in self.up_blocks: + hidden_states = self._gradient_checkpointing_func(up_block, hidden_states) + else: + hidden_states = self.mid_block(hidden_states) + + for up_block in self.up_blocks: + hidden_states = up_block(hidden_states) + + # post-process + hidden_states = self.norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + return hidden_states + + +class AutoencoderKLHunyuanVideo15(ModelMixin, AutoencoderMixin, ConfigMixin): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for + HunyuanVideo-1.5. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + latent_channels: int = 32, + block_out_channels: tuple[int] = (128, 256, 512, 1024, 1024), + layers_per_block: int = 2, + spatial_compression_ratio: int = 16, + temporal_compression_ratio: int = 4, + downsample_match_channel: bool = True, + upsample_match_channel: bool = True, + scaling_factor: float = 1.03682, + ) -> None: + super().__init__() + + self.encoder = HunyuanVideo15Encoder3D( + in_channels=in_channels, + out_channels=latent_channels * 2, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + temporal_compression_ratio=temporal_compression_ratio, + spatial_compression_ratio=spatial_compression_ratio, + downsample_match_channel=downsample_match_channel, + ) + + self.decoder = HunyuanVideo15Decoder3D( + in_channels=latent_channels, + out_channels=out_channels, + block_out_channels=list(reversed(block_out_channels)), + layers_per_block=layers_per_block, + temporal_compression_ratio=temporal_compression_ratio, + spatial_compression_ratio=spatial_compression_ratio, + upsample_match_channel=upsample_match_channel, + ) + + self.spatial_compression_ratio = spatial_compression_ratio + self.temporal_compression_ratio = temporal_compression_ratio + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal tile height and width in latent space + self.tile_latent_min_height = self.tile_sample_min_height // spatial_compression_ratio + self.tile_latent_min_width = self.tile_sample_min_width // spatial_compression_ratio + self.tile_overlap_factor = 0.25 + + def enable_tiling( + self, + tile_sample_min_height: int | None = None, + tile_sample_min_width: int | None = None, + tile_latent_min_height: int | None = None, + tile_latent_min_width: int | None = None, + tile_overlap_factor: float | None = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_latent_min_height (`int`, *optional*): + The minimum height required for a latent to be separated into tiles across the height dimension. + tile_latent_min_width (`int`, *optional*): + The minimum width required for a latent to be separated into tiles across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_latent_min_height = tile_latent_min_height or self.tile_latent_min_height + self.tile_latent_min_width = tile_latent_min_width or self.tile_latent_min_width + self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + _, _, _, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + x = self.encoder(x) + return x + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor) -> torch.Tensor: + _, _, _, height, width = z.shape + + if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height): + return self.tiled_decode(z) + + dec = self.decoder(z) + + return dec + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z) + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) + for x in range(blend_extent): + b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + _, _, _, height, width = x.shape + + overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192 + overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192 + blend_height = int(self.tile_latent_min_height * self.tile_overlap_factor) # 8 * 0.25 = 2 + blend_width = int(self.tile_latent_min_width * self.tile_overlap_factor) # 8 * 0.25 = 2 + row_limit_height = self.tile_latent_min_height - blend_height # 8 - 2 = 6 + row_limit_width = self.tile_latent_min_width - blend_width # 8 - 2 = 6 + + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + tile = x[ + :, + :, + :, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile = self.encoder(tile) + row.append(tile) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + moments = torch.cat(result_rows, dim=-2) + + return moments + + def tiled_decode(self, z: torch.Tensor) -> torch.Tensor: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + + _, _, _, height, width = z.shape + + overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6 + overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6 + blend_height = int(self.tile_sample_min_height * self.tile_overlap_factor) # 256 * 0.25 = 64 + blend_width = int(self.tile_sample_min_width * self.tile_overlap_factor) # 256 * 0.25 = 64 + row_limit_height = self.tile_sample_min_height - blend_height # 256 - 64 = 192 + row_limit_width = self.tile_sample_min_width - blend_width # 256 - 64 = 192 + + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + tile = z[ + :, + :, + :, + i : i + self.tile_latent_min_height, + j : j + self.tile_latent_min_width, + ] + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + dec = torch.cat(result_rows, dim=-2) + + return dec + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: torch.Generator | None = None, + ) -> DecoderOutput | torch.Tensor: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_ltx.py new file mode 100644 index 0000000000000000000000000000000000000000..a7acc105e9ec01a7ce2080525d5db42b726d7806 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -0,0 +1,1534 @@ +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from ..normalization import RMSNorm +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution + + +class LTXVideoCausalConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int, int] = 3, + stride: int | tuple[int, int, int] = 1, + dilation: int | tuple[int, int, int] = 1, + groups: int = 1, + padding_mode: str = "zeros", + is_causal: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.is_causal = is_causal + self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size) + + dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1) + stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + height_pad = self.kernel_size[1] // 2 + width_pad = self.kernel_size[2] // 2 + padding = (0, height_pad, width_pad) + + self.conv = nn.Conv3d( + in_channels, + out_channels, + self.kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + padding=padding, + padding_mode=padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + time_kernel_size = self.kernel_size[0] + + if self.is_causal: + pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, time_kernel_size - 1, 1, 1)) + hidden_states = torch.concatenate([pad_left, hidden_states], dim=2) + else: + pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1)) + pad_right = hidden_states[:, :, -1:, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1)) + hidden_states = torch.concatenate([pad_left, hidden_states, pad_right], dim=2) + + hidden_states = self.conv(hidden_states) + return hidden_states + + +class LTXVideoResnetBlock3d(nn.Module): + r""" + A 3D ResNet block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + dropout (`float`, defaults to `0.0`): + Dropout rate. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + elementwise_affine (`bool`, defaults to `False`): + Whether to enable elementwise affinity in the normalization layers. + non_linearity (`str`, defaults to `"swish"`): + Activation function to use. + conv_shortcut (bool, defaults to `False`): + Whether or not to use a convolution shortcut. + """ + + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + dropout: float = 0.0, + eps: float = 1e-6, + elementwise_affine: bool = False, + non_linearity: str = "swish", + is_causal: bool = True, + inject_noise: bool = False, + timestep_conditioning: bool = False, + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + + self.nonlinearity = get_activation(non_linearity) + + self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine) + self.conv1 = LTXVideoCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal + ) + + self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine) + self.dropout = nn.Dropout(dropout) + self.conv2 = LTXVideoCausalConv3d( + in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal + ) + + self.norm3 = None + self.conv_shortcut = None + if in_channels != out_channels: + self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True) + self.conv_shortcut = LTXVideoCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal + ) + + self.per_channel_scale1 = None + self.per_channel_scale2 = None + if inject_noise: + self.per_channel_scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1)) + self.per_channel_scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1)) + + self.scale_shift_table = None + if timestep_conditioning: + self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5) + + def forward( + self, inputs: torch.Tensor, temb: torch.Tensor | None = None, generator: torch.Generator | None = None + ) -> torch.Tensor: + hidden_states = inputs + + hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1) + + if self.scale_shift_table is not None: + temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None] + shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1) + hidden_states = hidden_states * (1 + scale_1) + shift_1 + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + if self.per_channel_scale1 is not None: + spatial_shape = hidden_states.shape[-2:] + spatial_noise = torch.randn( + spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype + )[None] + hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...] + + hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1) + + if self.scale_shift_table is not None: + hidden_states = hidden_states * (1 + scale_2) + shift_2 + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.per_channel_scale2 is not None: + spatial_shape = hidden_states.shape[-2:] + spatial_noise = torch.randn( + spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype + )[None] + hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, ...] + + if self.norm3 is not None: + inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1) + + if self.conv_shortcut is not None: + inputs = self.conv_shortcut(inputs) + + hidden_states = hidden_states + inputs + return hidden_states + + +class LTXVideoDownsampler3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int | tuple[int, int, int] = 1, + is_causal: bool = True, + padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels + + out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2]) + + self.conv = LTXVideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + is_causal=is_causal, + padding_mode=padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2) + + residual = ( + hidden_states.unflatten(4, (-1, self.stride[2])) + .unflatten(3, (-1, self.stride[1])) + .unflatten(2, (-1, self.stride[0])) + ) + residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) + residual = residual.unflatten(1, (-1, self.group_size)) + residual = residual.mean(dim=2) + + hidden_states = self.conv(hidden_states) + hidden_states = ( + hidden_states.unflatten(4, (-1, self.stride[2])) + .unflatten(3, (-1, self.stride[1])) + .unflatten(2, (-1, self.stride[0])) + ) + hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) + hidden_states = hidden_states + residual + + return hidden_states + + +class LTXVideoUpsampler3d(nn.Module): + def __init__( + self, + in_channels: int, + stride: int | tuple[int, int, int] = 1, + is_causal: bool = True, + residual: bool = False, + upscale_factor: int = 1, + padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.residual = residual + self.upscale_factor = upscale_factor + + out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor + + self.conv = LTXVideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + is_causal=is_causal, + padding_mode=padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + if self.residual: + residual = hidden_states.reshape( + batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width + ) + residual = residual.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) + repeats = (self.stride[0] * self.stride[1] * self.stride[2]) // self.upscale_factor + residual = residual.repeat(1, repeats, 1, 1, 1) + residual = residual[:, :, self.stride[0] - 1 :] + + hidden_states = self.conv(hidden_states) + hidden_states = hidden_states.reshape( + batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width + ) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) + hidden_states = hidden_states[:, :, self.stride[0] - 1 :] + + if self.residual: + hidden_states = hidden_states + residual + + return hidden_states + + +class LTXVideoDownBlock3D(nn.Module): + r""" + Down block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + spatio_temporal_scale (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + Whether or not to downsample across temporal dimension. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + spatio_temporal_scale: bool = True, + is_causal: bool = True, + ): + super().__init__() + + out_channels = out_channels or in_channels + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTXVideoResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.downsamplers = None + if spatio_temporal_scale: + self.downsamplers = nn.ModuleList( + [ + LTXVideoCausalConv3d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=(2, 2, 2), + is_causal=is_causal, + ) + ] + ) + + self.conv_out = None + if in_channels != out_channels: + self.conv_out = LTXVideoResnetBlock3d( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + generator: torch.Generator | None = None, + ) -> torch.Tensor: + r"""Forward method of the `LTXDownBlock3D` class.""" + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + else: + hidden_states = resnet(hidden_states, temb, generator) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + if self.conv_out is not None: + hidden_states = self.conv_out(hidden_states, temb, generator) + + return hidden_states + + +class LTXVideo095DownBlock3D(nn.Module): + r""" + Down block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + spatio_temporal_scale (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + Whether or not to downsample across temporal dimension. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + spatio_temporal_scale: bool = True, + is_causal: bool = True, + downsample_type: str = "conv", + ): + super().__init__() + + out_channels = out_channels or in_channels + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTXVideoResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.downsamplers = None + if spatio_temporal_scale: + self.downsamplers = nn.ModuleList() + + if downsample_type == "conv": + self.downsamplers.append( + LTXVideoCausalConv3d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=(2, 2, 2), + is_causal=is_causal, + ) + ) + elif downsample_type == "spatial": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal + ) + ) + elif downsample_type == "temporal": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), is_causal=is_causal + ) + ) + elif downsample_type == "spatiotemporal": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal + ) + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + generator: torch.Generator | None = None, + ) -> torch.Tensor: + r"""Forward method of the `LTXDownBlock3D` class.""" + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + else: + hidden_states = resnet(hidden_states, temb, generator) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d +class LTXVideoMidBlock3d(nn.Module): + r""" + A middle block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + is_causal: bool = True, + inject_noise: bool = False, + timestep_conditioning: bool = False, + ) -> None: + super().__init__() + + self.time_embedder = None + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTXVideoResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + generator: torch.Generator | None = None, + ) -> torch.Tensor: + r"""Forward method of the `LTXMidBlock3D` class.""" + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1) + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + else: + hidden_states = resnet(hidden_states, temb, generator) + + return hidden_states + + +class LTXVideoUpBlock3d(nn.Module): + r""" + Up block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + spatio_temporal_scale (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + Whether or not to downsample across temporal dimension. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + spatio_temporal_scale: bool = True, + is_causal: bool = True, + inject_noise: bool = False, + timestep_conditioning: bool = False, + upsample_residual: bool = False, + upscale_factor: int = 1, + ): + super().__init__() + + out_channels = out_channels or in_channels + + self.time_embedder = None + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) + + self.conv_in = None + if in_channels != out_channels: + self.conv_in = LTXVideoResnetBlock3d( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + ) + + self.upsamplers = None + if spatio_temporal_scale: + self.upsamplers = nn.ModuleList( + [ + LTXVideoUpsampler3d( + out_channels * upscale_factor, + stride=(2, 2, 2), + is_causal=is_causal, + residual=upsample_residual, + upscale_factor=upscale_factor, + ) + ] + ) + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTXVideoResnetBlock3d( + in_channels=out_channels, + out_channels=out_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + generator: torch.Generator | None = None, + ) -> torch.Tensor: + if self.conv_in is not None: + hidden_states = self.conv_in(hidden_states, temb, generator) + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + else: + hidden_states = resnet(hidden_states, temb, generator) + + return hidden_states + + +class LTXVideoEncoder3d(nn.Module): + r""" + The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent + representation. + + Args: + in_channels (`int`, defaults to 3): + Number of input channels. + out_channels (`int`, defaults to 128): + Number of latent channels. + block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + The number of output channels for each block. + spatio_temporal_scaling (`tuple[bool, ...], defaults to `(True, True, True, False)`: + Whether a block should contain spatio-temporal downscaling layers or not. + layers_per_block (`tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`): + The number of layers per block. + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 128, + block_out_channels: tuple[int, ...] = (128, 256, 512, 512), + down_block_types: tuple[str, ...] = ( + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + ), + spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, False), + layers_per_block: tuple[int, ...] = (4, 3, 3, 3, 4), + downsample_type: tuple[str, ...] = ("conv", "conv", "conv", "conv"), + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + is_causal: bool = True, + ): + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.in_channels = in_channels * patch_size**2 + + output_channel = block_out_channels[0] + + self.conv_in = LTXVideoCausalConv3d( + in_channels=self.in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + is_causal=is_causal, + ) + + # down blocks + is_ltx_095 = down_block_types[-1] == "LTXVideo095DownBlock3D" + num_block_out_channels = len(block_out_channels) - (1 if is_ltx_095 else 0) + self.down_blocks = nn.ModuleList([]) + for i in range(num_block_out_channels): + input_channel = output_channel + if not is_ltx_095: + output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i] + else: + output_channel = block_out_channels[i + 1] + + if down_block_types[i] == "LTXVideoDownBlock3D": + down_block = LTXVideoDownBlock3D( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + is_causal=is_causal, + ) + elif down_block_types[i] == "LTXVideo095DownBlock3D": + down_block = LTXVideo095DownBlock3D( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + is_causal=is_causal, + downsample_type=downsample_type[i], + ) + else: + raise ValueError(f"Unknown down block type: {down_block_types[i]}") + + self.down_blocks.append(down_block) + + # mid block + self.mid_block = LTXVideoMidBlock3d( + in_channels=output_channel, + num_layers=layers_per_block[-1], + resnet_eps=resnet_norm_eps, + is_causal=is_causal, + ) + + # out + self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) + self.conv_act = nn.SiLU() + self.conv_out = LTXVideoCausalConv3d( + in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal + ) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `LTXVideoEncoder3d` class.""" + + p = self.patch_size + p_t = self.patch_size_t + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + + hidden_states = hidden_states.reshape( + batch_size, num_channels, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p + ) + # Thanks for driving me insane with the weird patching order :( + hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4) + hidden_states = self.conv_in(hidden_states) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for down_block in self.down_blocks: + hidden_states = self._gradient_checkpointing_func(down_block, hidden_states) + + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states) + else: + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states) + + hidden_states = self.mid_block(hidden_states) + + hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + last_channel = hidden_states[:, -1:] + last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1) + hidden_states = torch.cat([hidden_states, last_channel], dim=1) + + return hidden_states + + +class LTXVideoDecoder3d(nn.Module): + r""" + The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output + sample. + + Args: + in_channels (`int`, defaults to 128): + Number of latent channels. + out_channels (`int`, defaults to 3): + Number of output channels. + block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + The number of output channels for each block. + spatio_temporal_scaling (`tuple[bool, ...], defaults to `(True, True, True, False)`: + Whether a block should contain spatio-temporal upscaling layers or not. + layers_per_block (`tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`): + The number of layers per block. + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + is_causal (`bool`, defaults to `False`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + timestep_conditioning (`bool`, defaults to `False`): + Whether to condition the model on timesteps. + """ + + def __init__( + self, + in_channels: int = 128, + out_channels: int = 3, + block_out_channels: tuple[int, ...] = (128, 256, 512, 512), + spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, False), + layers_per_block: tuple[int, ...] = (4, 3, 3, 3, 4), + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + is_causal: bool = False, + inject_noise: tuple[bool, ...] = (False, False, False, False), + timestep_conditioning: bool = False, + upsample_residual: tuple[bool, ...] = (False, False, False, False), + upsample_factor: tuple[bool, ...] = (1, 1, 1, 1), + ) -> None: + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.out_channels = out_channels * patch_size**2 + + block_out_channels = tuple(reversed(block_out_channels)) + spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling)) + layers_per_block = tuple(reversed(layers_per_block)) + inject_noise = tuple(reversed(inject_noise)) + upsample_residual = tuple(reversed(upsample_residual)) + upsample_factor = tuple(reversed(upsample_factor)) + output_channel = block_out_channels[0] + + self.conv_in = LTXVideoCausalConv3d( + in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal + ) + + self.mid_block = LTXVideoMidBlock3d( + in_channels=output_channel, + num_layers=layers_per_block[0], + resnet_eps=resnet_norm_eps, + is_causal=is_causal, + inject_noise=inject_noise[0], + timestep_conditioning=timestep_conditioning, + ) + + # up blocks + num_block_out_channels = len(block_out_channels) + self.up_blocks = nn.ModuleList([]) + for i in range(num_block_out_channels): + input_channel = output_channel // upsample_factor[i] + output_channel = block_out_channels[i] // upsample_factor[i] + + up_block = LTXVideoUpBlock3d( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i + 1], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + is_causal=is_causal, + inject_noise=inject_noise[i + 1], + timestep_conditioning=timestep_conditioning, + upsample_residual=upsample_residual[i], + upscale_factor=upsample_factor[i], + ) + + self.up_blocks.append(up_block) + + # out + self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) + self.conv_act = nn.SiLU() + self.conv_out = LTXVideoCausalConv3d( + in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal + ) + + # timestep embedding + self.time_embedder = None + self.scale_shift_table = None + self.timestep_scale_multiplier = None + if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32)) + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0) + self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor | None = None) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + + if self.timestep_scale_multiplier is not None: + temb = temb * self.timestep_scale_multiplier + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb) + + for up_block in self.up_blocks: + hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb) + else: + hidden_states = self.mid_block(hidden_states, temb) + + for up_block in self.up_blocks: + hidden_states = up_block(hidden_states, temb) + + hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1)) + temb = temb + self.scale_shift_table[None, ..., None, None, None] + shift, scale = temb.unbind(dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + p = self.patch_size + p_t = self.patch_size_t + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.reshape(batch_size, -1, p_t, p, p, num_frames, height, width) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 4, 7, 3).flatten(6, 7).flatten(4, 5).flatten(2, 3) + + return hidden_states + + +class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in + [LTX](https://huggingface.co/Lightricks/LTX-Video). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Args: + in_channels (`int`, defaults to `3`): + Number of input channels. + out_channels (`int`, defaults to `3`): + Number of output channels. + latent_channels (`int`, defaults to `128`): + Number of latent channels. + block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + The number of output channels for each block. + spatio_temporal_scaling (`tuple[bool, ...], defaults to `(True, True, True, False)`: + Whether a block should contain spatio-temporal downscaling or not. + layers_per_block (`tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`): + The number of layers per block. + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + scaling_factor (`float`, *optional*, defaults to `1.0`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper. + encoder_causal (`bool`, defaults to `True`): + Whether the encoder should behave causally (future frames depend only on past frames) or not. + decoder_causal (`bool`, defaults to `False`): + Whether the decoder should behave causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + latent_channels: int = 128, + block_out_channels: tuple[int, ...] = (128, 256, 512, 512), + down_block_types: tuple[str, ...] = ( + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + "LTXVideoDownBlock3D", + ), + decoder_block_out_channels: tuple[int, ...] = (128, 256, 512, 512), + layers_per_block: tuple[int, ...] = (4, 3, 3, 3, 4), + decoder_layers_per_block: tuple[int, ...] = (4, 3, 3, 3, 4), + spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, False), + decoder_spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, False), + decoder_inject_noise: tuple[bool, ...] = (False, False, False, False, False), + downsample_type: tuple[str, ...] = ("conv", "conv", "conv", "conv"), + upsample_residual: tuple[bool, ...] = (False, False, False, False), + upsample_factor: tuple[int, ...] = (1, 1, 1, 1), + timestep_conditioning: bool = False, + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + scaling_factor: float = 1.0, + encoder_causal: bool = True, + decoder_causal: bool = False, + spatial_compression_ratio: int = None, + temporal_compression_ratio: int = None, + ) -> None: + super().__init__() + + self.encoder = LTXVideoEncoder3d( + in_channels=in_channels, + out_channels=latent_channels, + block_out_channels=block_out_channels, + down_block_types=down_block_types, + spatio_temporal_scaling=spatio_temporal_scaling, + layers_per_block=layers_per_block, + downsample_type=downsample_type, + patch_size=patch_size, + patch_size_t=patch_size_t, + resnet_norm_eps=resnet_norm_eps, + is_causal=encoder_causal, + ) + self.decoder = LTXVideoDecoder3d( + in_channels=latent_channels, + out_channels=out_channels, + block_out_channels=decoder_block_out_channels, + spatio_temporal_scaling=decoder_spatio_temporal_scaling, + layers_per_block=decoder_layers_per_block, + patch_size=patch_size, + patch_size_t=patch_size_t, + resnet_norm_eps=resnet_norm_eps, + is_causal=decoder_causal, + timestep_conditioning=timestep_conditioning, + inject_noise=decoder_inject_noise, + upsample_residual=upsample_residual, + upsample_factor=upsample_factor, + ) + + latents_mean = torch.zeros((latent_channels,), requires_grad=False) + latents_std = torch.ones((latent_channels,), requires_grad=False) + self.register_buffer("latents_mean", latents_mean, persistent=True) + self.register_buffer("latents_std", latents_std, persistent=True) + + self.spatial_compression_ratio = ( + patch_size * 2 ** sum(spatio_temporal_scaling) + if spatial_compression_ratio is None + else spatial_compression_ratio + ) + self.temporal_compression_ratio = ( + patch_size_t * 2 ** sum(spatio_temporal_scaling) + if temporal_compression_ratio is None + else temporal_compression_ratio + ) + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames + # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered. + self.use_framewise_encoding = False + self.use_framewise_decoding = False + + # This can be configured based on the amount of GPU memory available. + # `16` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs. + # Setting it to higher values results in higher memory usage. + self.num_sample_frames_batch_size = 16 + self.num_latent_frames_batch_size = 2 + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 512 + self.tile_sample_min_width = 512 + self.tile_sample_min_num_frames = 16 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 448 + self.tile_sample_stride_width = 448 + self.tile_sample_stride_num_frames = 8 + + def enable_tiling( + self, + tile_sample_min_height: int | None = None, + tile_sample_min_width: int | None = None, + tile_sample_min_num_frames: int | None = None, + tile_sample_stride_height: float | None = None, + tile_sample_stride_width: float | None = None, + tile_sample_stride_num_frames: float | None = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames: + return self._temporal_tiled_encode(x) + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + enc = self.encoder(x) + + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode( + self, z: torch.Tensor, temb: torch.Tensor | None = None, return_dict: bool = True + ) -> DecoderOutput | torch.Tensor: + batch_size, num_channels, num_frames, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + + if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames: + return self._temporal_tiled_decode(z, temb, return_dict=return_dict) + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, temb, return_dict=return_dict) + + dec = self.decoder(z, temb) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, z: torch.Tensor, temb: torch.Tensor | None = None, return_dict: bool = True + ) -> DecoderOutput | torch.Tensor: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + if temb is not None: + decoded_slices = [ + self._decode(z_slice, t_slice).sample for z_slice, t_slice in (z.split(1), temb.split(1)) + ] + else: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z, temb).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) + for x in range(blend_extent): + b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + batch_size, num_channels, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + time = self.encoder( + x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + ) + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode( + self, z: torch.Tensor, temb: torch.Tensor | None, return_dict: bool = True + ) -> DecoderOutput | torch.Tensor: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + + batch_size, num_channels, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb) + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + batch_size, num_channels, num_frames, height, width = x.shape + latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1 + + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames + + row = [] + for i in range(0, num_frames, self.tile_sample_stride_num_frames): + tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :] + if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width): + tile = self.tiled_encode(tile) + else: + tile = self.encoder(tile) + if i > 0: + tile = tile[:, :, 1:, :, :] + row.append(tile) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :]) + else: + result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :]) + + enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames] + return enc + + def _temporal_tiled_decode( + self, z: torch.Tensor, temb: torch.Tensor | None, return_dict: bool = True + ) -> DecoderOutput | torch.Tensor: + batch_size, num_channels, num_frames, height, width = z.shape + num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames + + row = [] + for i in range(0, num_frames, tile_latent_stride_num_frames): + tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :] + if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height): + decoded = self.tiled_decode(tile, temb, return_dict=True).sample + else: + decoded = self.decoder(tile, temb) + if i > 0: + decoded = decoded[:, :, :-1, :, :] + row.append(decoded) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :] + result_row.append(tile) + else: + result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :]) + + dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + temb: torch.Tensor | None = None, + sample_posterior: bool = False, + return_dict: bool = True, + generator: torch.Generator | None = None, + ) -> torch.Tensor | torch.Tensor: + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, temb) + if not return_dict: + return (dec.sample,) + return dec diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_ltx2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_ltx2.py new file mode 100644 index 0000000000000000000000000000000000000000..7c04bd715c25d5be28dddda2cd182488a9d84f43 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_ltx2.py @@ -0,0 +1,1520 @@ +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution + + +class PerChannelRMSNorm(nn.Module): + """ + Per-pixel (per-location) RMS normalization layer. + + For each element along the chosen dimension, this layer normalizes the tensor by the root-mean-square of its values + across that dimension: + + y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps) + """ + + def __init__(self, channel_dim: int = 1, eps: float = 1e-8) -> None: + """ + Args: + dim: Dimension along which to compute the RMS (typically channels). + eps: Small constant added for numerical stability. + """ + super().__init__() + self.channel_dim = channel_dim + self.eps = eps + + def forward(self, x: torch.Tensor, channel_dim: int | None = None) -> torch.Tensor: + """ + Apply RMS normalization along the configured dimension. + """ + channel_dim = channel_dim or self.channel_dim + # Compute mean of squared values along `dim`, keep dimensions for broadcasting. + mean_sq = torch.mean(x**2, dim=self.channel_dim, keepdim=True) + # Normalize by the root-mean-square (RMS). + rms = torch.sqrt(mean_sq + self.eps) + return x / rms + + +# Like LTXCausalConv3d, but whether causal inference is performed can be specified at runtime +class LTX2VideoCausalConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int, int] = 3, + stride: int | tuple[int, int, int] = 1, + dilation: int | tuple[int, int, int] = 1, + groups: int = 1, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size) + + dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1) + stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + height_pad = self.kernel_size[1] // 2 + width_pad = self.kernel_size[2] // 2 + padding = (0, height_pad, width_pad) + + self.conv = nn.Conv3d( + in_channels, + out_channels, + self.kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + padding=padding, + padding_mode=spatial_padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor: + time_kernel_size = self.kernel_size[0] + + if causal: + pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, time_kernel_size - 1, 1, 1)) + hidden_states = torch.concatenate([pad_left, hidden_states], dim=2) + else: + pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1)) + pad_right = hidden_states[:, :, -1:, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1)) + hidden_states = torch.concatenate([pad_left, hidden_states, pad_right], dim=2) + + hidden_states = self.conv(hidden_states) + return hidden_states + + +# Like LTXVideoResnetBlock3d, but uses new causal Conv3d, normal Conv3d for the conv_shortcut, and the spatial padding +# mode is configurable +class LTX2VideoResnetBlock3d(nn.Module): + r""" + A 3D ResNet block used in the LTX 2.0 audiovisual model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + dropout (`float`, defaults to `0.0`): + Dropout rate. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + elementwise_affine (`bool`, defaults to `False`): + Whether to enable elementwise affinity in the normalization layers. + non_linearity (`str`, defaults to `"swish"`): + Activation function to use. + conv_shortcut (bool, defaults to `False`): + Whether or not to use a convolution shortcut. + """ + + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + dropout: float = 0.0, + eps: float = 1e-6, + elementwise_affine: bool = False, + non_linearity: str = "swish", + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + + self.nonlinearity = get_activation(non_linearity) + + self.norm1 = PerChannelRMSNorm() + self.conv1 = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + spatial_padding_mode=spatial_padding_mode, + ) + + self.norm2 = PerChannelRMSNorm() + self.dropout = nn.Dropout(dropout) + self.conv2 = LTX2VideoCausalConv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + spatial_padding_mode=spatial_padding_mode, + ) + + self.norm3 = None + self.conv_shortcut = None + if in_channels != out_channels: + self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True) + # LTX 2.0 uses a normal nn.Conv3d here rather than LTXVideoCausalConv3d + self.conv_shortcut = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1) + + self.per_channel_scale1 = None + self.per_channel_scale2 = None + if inject_noise: + self.per_channel_scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1)) + self.per_channel_scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1)) + + self.scale_shift_table = None + if timestep_conditioning: + self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5) + + def forward( + self, + inputs: torch.Tensor, + temb: torch.Tensor | None = None, + generator: torch.Generator | None = None, + causal: bool = True, + ) -> torch.Tensor: + hidden_states = inputs + + hidden_states = self.norm1(hidden_states) + + if self.scale_shift_table is not None: + temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None] + shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1) + hidden_states = hidden_states * (1 + scale_1) + shift_1 + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states, causal=causal) + + if self.per_channel_scale1 is not None: + spatial_shape = hidden_states.shape[-2:] + spatial_noise = torch.randn( + spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype + )[None] + hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...] + + hidden_states = self.norm2(hidden_states) + + if self.scale_shift_table is not None: + hidden_states = hidden_states * (1 + scale_2) + shift_2 + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states, causal=causal) + + if self.per_channel_scale2 is not None: + spatial_shape = hidden_states.shape[-2:] + spatial_noise = torch.randn( + spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype + )[None] + hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, ...] + + if self.norm3 is not None: + inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1) + + if self.conv_shortcut is not None: + inputs = self.conv_shortcut(inputs) + + hidden_states = hidden_states + inputs + return hidden_states + + +# Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d +class LTXVideoDownsampler3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int | tuple[int, int, int] = 1, + spatial_padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels + + out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2]) + + self.conv = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor: + hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2) + + residual = ( + hidden_states.unflatten(4, (-1, self.stride[2])) + .unflatten(3, (-1, self.stride[1])) + .unflatten(2, (-1, self.stride[0])) + ) + residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) + residual = residual.unflatten(1, (-1, self.group_size)) + residual = residual.mean(dim=2) + + hidden_states = self.conv(hidden_states, causal=causal) + hidden_states = ( + hidden_states.unflatten(4, (-1, self.stride[2])) + .unflatten(3, (-1, self.stride[1])) + .unflatten(2, (-1, self.stride[0])) + ) + hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) + hidden_states = hidden_states + residual + + return hidden_states + + +# Like LTX 1.0 LTXVideoUpsampler3d, but uses new causal Conv3d +class LTXVideoUpsampler3d(nn.Module): + def __init__( + self, + in_channels: int, + stride: int | tuple[int, int, int] = 1, + residual: bool = False, + upscale_factor: int = 1, + spatial_padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.residual = residual + self.upscale_factor = upscale_factor + + out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor + + self.conv = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + if self.residual: + residual = hidden_states.reshape( + batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width + ) + residual = residual.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) + repeats = (self.stride[0] * self.stride[1] * self.stride[2]) // self.upscale_factor + residual = residual.repeat(1, repeats, 1, 1, 1) + residual = residual[:, :, self.stride[0] - 1 :] + + hidden_states = self.conv(hidden_states, causal=causal) + hidden_states = hidden_states.reshape( + batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width + ) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) + hidden_states = hidden_states[:, :, self.stride[0] - 1 :] + + if self.residual: + hidden_states = hidden_states + residual + + return hidden_states + + +# Like LTX 1.0 LTXVideo095DownBlock3D, but with the updated LTX2VideoResnetBlock3d +class LTX2VideoDownBlock3D(nn.Module): + r""" + Down block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + spatio_temporal_scale (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + Whether or not to downsample across temporal dimension. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + spatio_temporal_scale: bool = True, + downsample_type: str = "conv", + spatial_padding_mode: str = "zeros", + ): + super().__init__() + + out_channels = out_channels or in_channels + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTX2VideoResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + spatial_padding_mode=spatial_padding_mode, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.downsamplers = None + if spatio_temporal_scale: + self.downsamplers = nn.ModuleList() + + if downsample_type == "conv": + self.downsamplers.append( + LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + ) + elif downsample_type == "spatial": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, + out_channels=out_channels, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + ) + elif downsample_type == "temporal": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + ) + elif downsample_type == "spatiotemporal": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + generator: torch.Generator | None = None, + causal: bool = True, + ) -> torch.Tensor: + r"""Forward method of the `LTXDownBlock3D` class.""" + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal) + else: + hidden_states = resnet(hidden_states, temb, generator, causal=causal) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, causal=causal) + + return hidden_states + + +# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d +# Like LTX 1.0 LTXVideoMidBlock3d, but with the updated LTX2VideoResnetBlock3d +class LTX2VideoMidBlock3d(nn.Module): + r""" + A middle block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.time_embedder = None + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTX2VideoResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + generator: torch.Generator | None = None, + causal: bool = True, + ) -> torch.Tensor: + r"""Forward method of the `LTXMidBlock3D` class.""" + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1) + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal) + else: + hidden_states = resnet(hidden_states, temb, generator, causal=causal) + + return hidden_states + + +# Like LTXVideoUpBlock3d but with no conv_in and the updated LTX2VideoResnetBlock3d +class LTX2VideoUpBlock3d(nn.Module): + r""" + Up block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + spatio_temporal_scale (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + Whether or not to downsample across temporal dimension. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + spatio_temporal_scale: bool = True, + inject_noise: bool = False, + timestep_conditioning: bool = False, + upsample_residual: bool = False, + upscale_factor: int = 1, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + + out_channels = out_channels or in_channels + + self.time_embedder = None + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) + + self.conv_in = None + if in_channels != out_channels: + self.conv_in = LTX2VideoResnetBlock3d( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + + self.upsamplers = None + if spatio_temporal_scale: + self.upsamplers = nn.ModuleList( + [ + LTXVideoUpsampler3d( + out_channels * upscale_factor, + stride=(2, 2, 2), + residual=upsample_residual, + upscale_factor=upscale_factor, + spatial_padding_mode=spatial_padding_mode, + ) + ] + ) + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTX2VideoResnetBlock3d( + in_channels=out_channels, + out_channels=out_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + generator: torch.Generator | None = None, + causal: bool = True, + ) -> torch.Tensor: + if self.conv_in is not None: + hidden_states = self.conv_in(hidden_states, temb, generator, causal=causal) + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, causal=causal) + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal) + else: + hidden_states = resnet(hidden_states, temb, generator, causal=causal) + + return hidden_states + + +# Like LTX 1.0 LTXVideoEncoder3d but with different default args - the spatiotemporal downsampling pattern is +# different, as is the layers_per_block (the 2.0 VAE is bigger) +class LTX2VideoEncoder3d(nn.Module): + r""" + The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent + representation. + + Args: + in_channels (`int`, defaults to 3): + Number of input channels. + out_channels (`int`, defaults to 128): + Number of latent channels. + block_out_channels (`tuple[int, ...]`, defaults to `(256, 512, 1024, 2048)`): + The number of output channels for each block. + spatio_temporal_scaling (`tuple[bool, ...], defaults to `(True, True, True, True)`: + Whether a block should contain spatio-temporal downscaling layers or not. + layers_per_block (`tuple[int, ...]`, defaults to `(4, 6, 6, 2, 2)`): + The number of layers per block. + downsample_type (`tuple[str, ...]`, defaults to `("spatial", "temporal", "spatiotemporal", "spatiotemporal")`): + The spatiotemporal downsampling pattern per block. Per-layer values can be + - `"spatial"` (downsample spatial dims by 2x) + - `"temporal"` (downsample temporal dim by 2x) + - `"spatiotemporal"` (downsample both spatial and temporal dims by 2x) + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 128, + block_out_channels: tuple[int, ...] = (256, 512, 1024, 2048), + down_block_types: tuple[str, ...] = ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True), + layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2), + downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + is_causal: bool = True, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.in_channels = in_channels * patch_size**2 + self.is_causal = is_causal + + output_channel = out_channels + + self.conv_in = LTX2VideoCausalConv3d( + in_channels=self.in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, + ) + + # down blocks + num_block_out_channels = len(block_out_channels) + self.down_blocks = nn.ModuleList([]) + for i in range(num_block_out_channels): + input_channel = output_channel + output_channel = block_out_channels[i] + + if down_block_types[i] == "LTX2VideoDownBlock3D": + down_block = LTX2VideoDownBlock3D( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + downsample_type=downsample_type[i], + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"Unknown down block type: {down_block_types[i]}") + + self.down_blocks.append(down_block) + + # mid block + self.mid_block = LTX2VideoMidBlock3d( + in_channels=output_channel, + num_layers=layers_per_block[-1], + resnet_eps=resnet_norm_eps, + spatial_padding_mode=spatial_padding_mode, + ) + + # out + self.norm_out = PerChannelRMSNorm() + self.conv_act = nn.SiLU() + self.conv_out = LTX2VideoCausalConv3d( + in_channels=output_channel, + out_channels=out_channels + 1, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, + ) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor, causal: bool | None = None) -> torch.Tensor: + r"""The forward method of the `LTXVideoEncoder3d` class.""" + + p = self.patch_size + p_t = self.patch_size_t + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + causal = causal or self.is_causal + + hidden_states = hidden_states.reshape( + batch_size, num_channels, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p + ) + # Thanks for driving me insane with the weird patching order :( + hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4) + hidden_states = self.conv_in(hidden_states, causal=causal) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for down_block in self.down_blocks: + hidden_states = self._gradient_checkpointing_func(down_block, hidden_states, None, None, causal) + + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, None, None, causal) + else: + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states, causal=causal) + + hidden_states = self.mid_block(hidden_states, causal=causal) + + hidden_states = self.norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states, causal=causal) + + last_channel = hidden_states[:, -1:] + last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1) + hidden_states = torch.cat([hidden_states, last_channel], dim=1) + + return hidden_states + + +# Like LTX 1.0 LTXVideoDecoder3d, but has only 3 symmetric up blocks which are causal and residual with upsample_factor 2 +class LTX2VideoDecoder3d(nn.Module): + r""" + The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output + sample. + + Args: + in_channels (`int`, defaults to 128): + Number of latent channels. + out_channels (`int`, defaults to 3): + Number of output channels. + block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + The number of output channels for each block. + spatio_temporal_scaling (`tuple[bool, ...], defaults to `(True, True, True, False)`: + Whether a block should contain spatio-temporal upscaling layers or not. + layers_per_block (`tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`): + The number of layers per block. + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + is_causal (`bool`, defaults to `False`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + timestep_conditioning (`bool`, defaults to `False`): + Whether to condition the model on timesteps. + """ + + def __init__( + self, + in_channels: int = 128, + out_channels: int = 3, + block_out_channels: tuple[int, ...] = (256, 512, 1024), + spatio_temporal_scaling: tuple[bool, ...] = (True, True, True), + layers_per_block: tuple[int, ...] = (5, 5, 5, 5), + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + is_causal: bool = False, + inject_noise: tuple[bool, ...] = (False, False, False), + timestep_conditioning: bool = False, + upsample_residual: tuple[bool, ...] = (True, True, True), + upsample_factor: tuple[bool, ...] = (2, 2, 2), + spatial_padding_mode: str = "reflect", + ) -> None: + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.out_channels = out_channels * patch_size**2 + self.is_causal = is_causal + + block_out_channels = tuple(reversed(block_out_channels)) + spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling)) + layers_per_block = tuple(reversed(layers_per_block)) + inject_noise = tuple(reversed(inject_noise)) + upsample_residual = tuple(reversed(upsample_residual)) + upsample_factor = tuple(reversed(upsample_factor)) + output_channel = block_out_channels[0] + + self.conv_in = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, + ) + + self.mid_block = LTX2VideoMidBlock3d( + in_channels=output_channel, + num_layers=layers_per_block[0], + resnet_eps=resnet_norm_eps, + inject_noise=inject_noise[0], + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + + # up blocks + num_block_out_channels = len(block_out_channels) + self.up_blocks = nn.ModuleList([]) + for i in range(num_block_out_channels): + input_channel = output_channel // upsample_factor[i] + output_channel = block_out_channels[i] // upsample_factor[i] + + up_block = LTX2VideoUpBlock3d( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i + 1], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + inject_noise=inject_noise[i + 1], + timestep_conditioning=timestep_conditioning, + upsample_residual=upsample_residual[i], + upscale_factor=upsample_factor[i], + spatial_padding_mode=spatial_padding_mode, + ) + + self.up_blocks.append(up_block) + + # out + self.norm_out = PerChannelRMSNorm() + self.conv_act = nn.SiLU() + self.conv_out = LTX2VideoCausalConv3d( + in_channels=output_channel, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, + ) + + # timestep embedding + self.time_embedder = None + self.scale_shift_table = None + self.timestep_scale_multiplier = None + if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32)) + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0) + self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + causal: bool | None = None, + ) -> torch.Tensor: + causal = causal or self.is_causal + + hidden_states = self.conv_in(hidden_states, causal=causal) + + if self.timestep_scale_multiplier is not None: + temb = temb * self.timestep_scale_multiplier + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb, None, causal) + + for up_block in self.up_blocks: + hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb, None, causal) + else: + hidden_states = self.mid_block(hidden_states, temb, causal=causal) + + for up_block in self.up_blocks: + hidden_states = up_block(hidden_states, temb, causal=causal) + + hidden_states = self.norm_out(hidden_states) + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1)) + temb = temb + self.scale_shift_table[None, ..., None, None, None] + shift, scale = temb.unbind(dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states, causal=causal) + + p = self.patch_size + p_t = self.patch_size_t + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.reshape(batch_size, -1, p_t, p, p, num_frames, height, width) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 4, 7, 3).flatten(6, 7).flatten(4, 5).flatten(2, 3) + + return hidden_states + + +class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in + [LTX-2](https://huggingface.co/Lightricks/LTX-2). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Args: + in_channels (`int`, defaults to `3`): + Number of input channels. + out_channels (`int`, defaults to `3`): + Number of output channels. + latent_channels (`int`, defaults to `128`): + Number of latent channels. + block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + The number of output channels for each block. + spatio_temporal_scaling (`tuple[bool, ...], defaults to `(True, True, True, False)`: + Whether a block should contain spatio-temporal downscaling or not. + layers_per_block (`tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`): + The number of layers per block. + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + scaling_factor (`float`, *optional*, defaults to `1.0`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper. + encoder_causal (`bool`, defaults to `True`): + Whether the encoder should behave causally (future frames depend only on past frames) or not. + decoder_causal (`bool`, defaults to `False`): + Whether the decoder should behave causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + latent_channels: int = 128, + block_out_channels: tuple[int, ...] = (256, 512, 1024, 2048), + down_block_types: tuple[str, ...] = ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + decoder_block_out_channels: tuple[int, ...] = (256, 512, 1024), + layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2), + decoder_layers_per_block: tuple[int, ...] = (5, 5, 5, 5), + spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True), + decoder_spatio_temporal_scaling: tuple[bool, ...] = (True, True, True), + decoder_inject_noise: tuple[bool, ...] = (False, False, False, False), + downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + upsample_residual: tuple[bool, ...] = (True, True, True), + upsample_factor: tuple[int, ...] = (2, 2, 2), + timestep_conditioning: bool = False, + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + scaling_factor: float = 1.0, + encoder_causal: bool = True, + decoder_causal: bool = True, + encoder_spatial_padding_mode: str = "zeros", + decoder_spatial_padding_mode: str = "reflect", + spatial_compression_ratio: int = None, + temporal_compression_ratio: int = None, + ) -> None: + super().__init__() + + self.encoder = LTX2VideoEncoder3d( + in_channels=in_channels, + out_channels=latent_channels, + block_out_channels=block_out_channels, + down_block_types=down_block_types, + spatio_temporal_scaling=spatio_temporal_scaling, + layers_per_block=layers_per_block, + downsample_type=downsample_type, + patch_size=patch_size, + patch_size_t=patch_size_t, + resnet_norm_eps=resnet_norm_eps, + is_causal=encoder_causal, + spatial_padding_mode=encoder_spatial_padding_mode, + ) + self.decoder = LTX2VideoDecoder3d( + in_channels=latent_channels, + out_channels=out_channels, + block_out_channels=decoder_block_out_channels, + spatio_temporal_scaling=decoder_spatio_temporal_scaling, + layers_per_block=decoder_layers_per_block, + patch_size=patch_size, + patch_size_t=patch_size_t, + resnet_norm_eps=resnet_norm_eps, + is_causal=decoder_causal, + timestep_conditioning=timestep_conditioning, + inject_noise=decoder_inject_noise, + upsample_residual=upsample_residual, + upsample_factor=upsample_factor, + spatial_padding_mode=decoder_spatial_padding_mode, + ) + + latents_mean = torch.zeros((latent_channels,), requires_grad=False) + latents_std = torch.ones((latent_channels,), requires_grad=False) + self.register_buffer("latents_mean", latents_mean, persistent=True) + self.register_buffer("latents_std", latents_std, persistent=True) + + self.spatial_compression_ratio = ( + patch_size * 2 ** sum(spatio_temporal_scaling) + if spatial_compression_ratio is None + else spatial_compression_ratio + ) + self.temporal_compression_ratio = ( + patch_size_t * 2 ** sum(spatio_temporal_scaling) + if temporal_compression_ratio is None + else temporal_compression_ratio + ) + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames + # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered. + self.use_framewise_encoding = False + self.use_framewise_decoding = False + + # This can be configured based on the amount of GPU memory available. + # `16` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs. + # Setting it to higher values results in higher memory usage. + self.num_sample_frames_batch_size = 16 + self.num_latent_frames_batch_size = 2 + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 512 + self.tile_sample_min_width = 512 + self.tile_sample_min_num_frames = 16 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 448 + self.tile_sample_stride_width = 448 + self.tile_sample_stride_num_frames = 8 + + def enable_tiling( + self, + tile_sample_min_height: int | None = None, + tile_sample_min_width: int | None = None, + tile_sample_min_num_frames: int | None = None, + tile_sample_stride_height: float | None = None, + tile_sample_stride_width: float | None = None, + tile_sample_stride_num_frames: float | None = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames + + def _encode(self, x: torch.Tensor, causal: bool | None = None) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames: + return self._temporal_tiled_encode(x, causal=causal) + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x, causal=causal) + + enc = self.encoder(x, causal=causal) + + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, causal: bool | None = None, return_dict: bool = True + ) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice, causal=causal) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x, causal=causal) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode( + self, + z: torch.Tensor, + temb: torch.Tensor | None = None, + causal: bool | None = None, + return_dict: bool = True, + ) -> DecoderOutput | torch.Tensor: + batch_size, num_channels, num_frames, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + + if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames: + return self._temporal_tiled_decode(z, temb, causal=causal, return_dict=return_dict) + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, temb, causal=causal, return_dict=return_dict) + + dec = self.decoder(z, temb, causal=causal) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, + z: torch.Tensor, + temb: torch.Tensor | None = None, + causal: bool | None = None, + return_dict: bool = True, + ) -> DecoderOutput | torch.Tensor: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + if temb is not None: + decoded_slices = [ + self._decode(z_slice, t_slice, causal=causal).sample + for z_slice, t_slice in (z.split(1), temb.split(1)) + ] + else: + decoded_slices = [self._decode(z_slice, causal=causal).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z, temb, causal=causal).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) + for x in range(blend_extent): + b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor, causal: bool | None = None) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + batch_size, num_channels, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + time = self.encoder( + x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width], + causal=causal, + ) + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode( + self, z: torch.Tensor, temb: torch.Tensor | None, causal: bool | None = None, return_dict: bool = True + ) -> DecoderOutput | torch.Tensor: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + + batch_size, num_channels, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + time = self.decoder( + z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb, causal=causal + ) + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def _temporal_tiled_encode(self, x: torch.Tensor, causal: bool | None = None) -> AutoencoderKLOutput: + batch_size, num_channels, num_frames, height, width = x.shape + latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1 + + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames + + row = [] + for i in range(0, num_frames, self.tile_sample_stride_num_frames): + tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :] + if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width): + tile = self.tiled_encode(tile, causal=causal) + else: + tile = self.encoder(tile, causal=causal) + if i > 0: + tile = tile[:, :, 1:, :, :] + row.append(tile) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :]) + else: + result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :]) + + enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames] + return enc + + def _temporal_tiled_decode( + self, z: torch.Tensor, temb: torch.Tensor | None, causal: bool | None = None, return_dict: bool = True + ) -> DecoderOutput | torch.Tensor: + batch_size, num_channels, num_frames, height, width = z.shape + num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames + + row = [] + for i in range(0, num_frames, tile_latent_stride_num_frames): + tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :] + if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height): + decoded = self.tiled_decode(tile, temb, causal=causal, return_dict=True).sample + else: + decoded = self.decoder(tile, temb, causal=causal) + if i > 0: + decoded = decoded[:, :, :-1, :, :] + row.append(decoded) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :] + result_row.append(tile) + else: + result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :]) + + dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + temb: torch.Tensor | None = None, + sample_posterior: bool = False, + encoder_causal: bool | None = None, + decoder_causal: bool | None = None, + return_dict: bool = True, + generator: torch.Generator | None = None, + ) -> torch.Tensor | torch.Tensor: + x = sample + posterior = self.encode(x, causal=encoder_causal).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, temb, causal=decoder_causal) + if not return_dict: + return (dec.sample,) + return dec diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..f9390dab5b74a15828b793a83ba7acf3884031c6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -0,0 +1,803 @@ +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils.accelerate_utils import apply_forward_hook +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution + + +LATENT_DOWNSAMPLE_FACTOR = 4 + + +class LTX2AudioCausalConv2d(nn.Module): + """ + A causal 2D convolution that pads asymmetrically along the causal axis. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int], + stride: int = 1, + dilation: int | tuple[int, int] = 1, + groups: int = 1, + bias: bool = True, + causality_axis: str = "height", + ) -> None: + super().__init__() + + self.causality_axis = causality_axis + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + dilation = (dilation, dilation) if isinstance(dilation, int) else dilation + + pad_h = (kernel_size[0] - 1) * dilation[0] + pad_w = (kernel_size[1] - 1) * dilation[1] + + if self.causality_axis == "none": + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + elif self.causality_axis in {"width", "width-compatibility"}: + padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2) + elif self.causality_axis == "height": + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0) + else: + raise ValueError(f"Invalid causality_axis: {causality_axis}") + + self.padding = padding + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.pad(x, self.padding) + return self.conv(x) + + +class LTX2AudioPixelNorm(nn.Module): + """ + Per-pixel (per-location) RMS normalization layer. + """ + + def __init__(self, dim: int = 1, eps: float = 1e-8) -> None: + super().__init__() + self.dim = dim + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True) + rms = torch.sqrt(mean_sq + self.eps) + return x / rms + + +class LTX2AudioAttnBlock(nn.Module): + def __init__( + self, + in_channels: int, + norm_type: str = "group", + ) -> None: + super().__init__() + self.in_channels = in_channels + + if norm_type == "group": + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + elif norm_type == "pixel": + self.norm = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {norm_type}") + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = self.norm(x) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + batch, channels, height, width = q.shape + q = q.reshape(batch, channels, height * width).permute(0, 2, 1).contiguous() + k = k.reshape(batch, channels, height * width).contiguous() + attn = torch.bmm(q, k) * (int(channels) ** (-0.5)) + attn = torch.nn.functional.softmax(attn, dim=2) + + v = v.reshape(batch, channels, height * width) + attn = attn.permute(0, 2, 1).contiguous() + h_ = torch.bmm(v, attn).reshape(batch, channels, height, width) + + h_ = self.proj_out(h_) + return x + h_ + + +class LTX2AudioResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + norm_type: str = "group", + causality_axis: str = "height", + ) -> None: + super().__init__() + self.causality_axis = causality_axis + + if self.causality_axis is not None and self.causality_axis != "none" and norm_type == "group": + raise ValueError("Causal ResnetBlock with GroupNorm is not supported.") + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + if norm_type == "group": + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + elif norm_type == "pixel": + self.norm1 = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {norm_type}") + self.non_linearity = nn.SiLU() + if causality_axis is not None: + self.conv1 = LTX2AudioCausalConv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels > 0: + self.temb_proj = nn.Linear(temb_channels, out_channels) + if norm_type == "group": + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + elif norm_type == "pixel": + self.norm2 = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {norm_type}") + self.dropout = nn.Dropout(dropout) + if causality_axis is not None: + self.conv2 = LTX2AudioCausalConv2d( + out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + if causality_axis is not None: + self.conv_shortcut = LTX2AudioCausalConv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + if causality_axis is not None: + self.nin_shortcut = LTX2AudioCausalConv2d( + in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis + ) + else: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor, temb: torch.Tensor | None = None) -> torch.Tensor: + h = self.norm1(x) + h = self.non_linearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = self.non_linearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x) + + return x + h + + +class LTX2AudioDownsample(nn.Module): + def __init__(self, in_channels: int, with_conv: bool, causality_axis: str | None = "height") -> None: + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.with_conv: + # Padding tuple is in the order: (left, right, top, bottom). + if self.causality_axis == "none": + pad = (0, 1, 0, 1) + elif self.causality_axis == "width": + pad = (2, 0, 0, 1) + elif self.causality_axis == "height": + pad = (0, 1, 2, 0) + elif self.causality_axis == "width-compatibility": + pad = (1, 0, 0, 1) + else: + raise ValueError( + f"Invalid `causality_axis` {self.causality_axis}; supported values are `none`, `width`, `height`," + f" and `width-compatibility`." + ) + + x = F.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + # with_conv=False implies that causality_axis is "none" + x = F.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class LTX2AudioUpsample(nn.Module): + def __init__(self, in_channels: int, with_conv: bool, causality_axis: str | None = "height") -> None: + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + if self.with_conv: + if causality_axis is not None: + self.conv = LTX2AudioCausalConv2d( + in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + if self.causality_axis is None or self.causality_axis == "none": + pass + elif self.causality_axis == "height": + x = x[:, :, 1:, :] + elif self.causality_axis == "width": + x = x[:, :, :, 1:] + elif self.causality_axis == "width-compatibility": + pass + else: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + return x + + +class LTX2AudioAudioPatchifier: + """ + Patchifier for spectrogram/audio latents. + """ + + def __init__( + self, + patch_size: int, + sample_rate: int = 16000, + hop_length: int = 160, + audio_latent_downsample_factor: int = 4, + is_causal: bool = True, + ): + self.hop_length = hop_length + self.sample_rate = sample_rate + self.audio_latent_downsample_factor = audio_latent_downsample_factor + self.is_causal = is_causal + self._patch_size = (1, patch_size, patch_size) + + def patchify(self, audio_latents: torch.Tensor) -> torch.Tensor: + batch, channels, time, freq = audio_latents.shape + return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq) + + def unpatchify(self, audio_latents: torch.Tensor, channels: int, mel_bins: int) -> torch.Tensor: + batch, time, _ = audio_latents.shape + return audio_latents.view(batch, time, channels, mel_bins).permute(0, 2, 1, 3) + + @property + def patch_size(self) -> tuple[int, int, int]: + return self._patch_size + + +class LTX2AudioEncoder(nn.Module): + def __init__( + self, + base_channels: int = 128, + output_channels: int = 1, + num_res_blocks: int = 2, + attn_resolutions: tuple[int, ...] | None = None, + in_channels: int = 2, + resolution: int = 256, + latent_channels: int = 8, + ch_mult: tuple[int, ...] = (1, 2, 4), + norm_type: str = "group", + causality_axis: str | None = "width", + dropout: float = 0.0, + mid_block_add_attention: bool = False, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: int | None = 64, + double_z: bool = True, + ): + super().__init__() + + self.sample_rate = sample_rate + self.mel_hop_length = mel_hop_length + self.is_causal = is_causal + self.mel_bins = mel_bins + + self.base_channels = base_channels + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.out_ch = output_channels + self.give_pre_end = False + self.tanh_out = False + self.norm_type = norm_type + self.latent_channels = latent_channels + self.channel_multipliers = ch_mult + self.attn_resolutions = attn_resolutions + self.causality_axis = causality_axis + + base_block_channels = base_channels + base_resolution = resolution + self.z_shape = (1, latent_channels, base_resolution, base_resolution) + + if self.causality_axis is not None: + self.conv_in = LTX2AudioCausalConv2d( + in_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + else: + self.conv_in = nn.Conv2d(in_channels, base_block_channels, kernel_size=3, stride=1, padding=1) + + self.down = nn.ModuleList() + block_in = base_block_channels + curr_res = self.resolution + + for level in range(self.num_resolutions): + stage = nn.Module() + stage.block = nn.ModuleList() + stage.attn = nn.ModuleList() + block_out = self.base_channels * self.channel_multipliers[level] + + for _ in range(self.num_res_blocks): + stage.block.append( + LTX2AudioResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + ) + block_in = block_out + if self.attn_resolutions: + if curr_res in self.attn_resolutions: + stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)) + + if level != self.num_resolutions - 1: + stage.downsample = LTX2AudioDownsample(block_in, True, causality_axis=self.causality_axis) + curr_res = curr_res // 2 + + self.down.append(stage) + + self.mid = nn.Module() + self.mid.block_1 = LTX2AudioResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + if mid_block_add_attention: + self.mid.attn_1 = LTX2AudioAttnBlock(block_in, norm_type=self.norm_type) + else: + self.mid.attn_1 = nn.Identity() + self.mid.block_2 = LTX2AudioResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + + final_block_channels = block_in + z_channels = 2 * latent_channels if double_z else latent_channels + if self.norm_type == "group": + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True) + elif self.norm_type == "pixel": + self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {self.norm_type}") + self.non_linearity = nn.SiLU() + + if self.causality_axis is not None: + self.conv_out = LTX2AudioCausalConv2d( + final_block_channels, z_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + else: + self.conv_out = nn.Conv2d(final_block_channels, z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # hidden_states expected shape: (batch_size, channels, time, num_mel_bins) + hidden_states = self.conv_in(hidden_states) + + for level in range(self.num_resolutions): + stage = self.down[level] + for block_idx, block in enumerate(stage.block): + hidden_states = block(hidden_states, temb=None) + if stage.attn: + hidden_states = stage.attn[block_idx](hidden_states) + + if level != self.num_resolutions - 1 and hasattr(stage, "downsample"): + hidden_states = stage.downsample(hidden_states) + + hidden_states = self.mid.block_1(hidden_states, temb=None) + hidden_states = self.mid.attn_1(hidden_states) + hidden_states = self.mid.block_2(hidden_states, temb=None) + + hidden_states = self.norm_out(hidden_states) + hidden_states = self.non_linearity(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +class LTX2AudioDecoder(nn.Module): + """ + Symmetric decoder that reconstructs audio spectrograms from latent features. + + The decoder mirrors the encoder structure with configurable channel multipliers, attention resolutions, and causal + convolutions. + """ + + def __init__( + self, + base_channels: int = 128, + output_channels: int = 1, + num_res_blocks: int = 2, + attn_resolutions: tuple[int, ...] | None = None, + in_channels: int = 2, + resolution: int = 256, + latent_channels: int = 8, + ch_mult: tuple[int, ...] = (1, 2, 4), + norm_type: str = "group", + causality_axis: str | None = "width", + dropout: float = 0.0, + mid_block_add_attention: bool = False, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: int | None = 64, + ) -> None: + super().__init__() + + self.sample_rate = sample_rate + self.mel_hop_length = mel_hop_length + self.is_causal = is_causal + self.mel_bins = mel_bins + self.patchifier = LTX2AudioAudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=sample_rate, + hop_length=mel_hop_length, + is_causal=is_causal, + ) + + self.base_channels = base_channels + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.out_ch = output_channels + self.give_pre_end = False + self.tanh_out = False + self.norm_type = norm_type + self.latent_channels = latent_channels + self.channel_multipliers = ch_mult + self.attn_resolutions = attn_resolutions + self.causality_axis = causality_axis + + base_block_channels = base_channels * self.channel_multipliers[-1] + base_resolution = resolution // (2 ** (self.num_resolutions - 1)) + self.z_shape = (1, latent_channels, base_resolution, base_resolution) + + if self.causality_axis is not None: + self.conv_in = LTX2AudioCausalConv2d( + latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + else: + self.conv_in = nn.Conv2d(latent_channels, base_block_channels, kernel_size=3, stride=1, padding=1) + self.non_linearity = nn.SiLU() + self.mid = nn.Module() + self.mid.block_1 = LTX2AudioResnetBlock( + in_channels=base_block_channels, + out_channels=base_block_channels, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + if mid_block_add_attention: + self.mid.attn_1 = LTX2AudioAttnBlock(base_block_channels, norm_type=self.norm_type) + else: + self.mid.attn_1 = nn.Identity() + self.mid.block_2 = LTX2AudioResnetBlock( + in_channels=base_block_channels, + out_channels=base_block_channels, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + + self.up = nn.ModuleList() + block_in = base_block_channels + curr_res = self.resolution // (2 ** (self.num_resolutions - 1)) + + for level in reversed(range(self.num_resolutions)): + stage = nn.Module() + stage.block = nn.ModuleList() + stage.attn = nn.ModuleList() + block_out = self.base_channels * self.channel_multipliers[level] + + for _ in range(self.num_res_blocks + 1): + stage.block.append( + LTX2AudioResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + ) + block_in = block_out + if self.attn_resolutions: + if curr_res in self.attn_resolutions: + stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)) + + if level != 0: + stage.upsample = LTX2AudioUpsample(block_in, True, causality_axis=self.causality_axis) + curr_res *= 2 + + self.up.insert(0, stage) + + final_block_channels = block_in + + if self.norm_type == "group": + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True) + elif self.norm_type == "pixel": + self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {self.norm_type}") + + if self.causality_axis is not None: + self.conv_out = LTX2AudioCausalConv2d( + final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + else: + self.conv_out = nn.Conv2d(final_block_channels, output_channels, kernel_size=3, stride=1, padding=1) + + def forward( + self, + sample: torch.Tensor, + ) -> torch.Tensor: + _, _, frames, mel_bins = sample.shape + + target_frames = frames * LATENT_DOWNSAMPLE_FACTOR + + if self.causality_axis is not None: + target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1) + + target_channels = self.out_ch + target_mel_bins = self.mel_bins if self.mel_bins is not None else mel_bins + + hidden_features = self.conv_in(sample) + hidden_features = self.mid.block_1(hidden_features, temb=None) + hidden_features = self.mid.attn_1(hidden_features) + hidden_features = self.mid.block_2(hidden_features, temb=None) + + for level in reversed(range(self.num_resolutions)): + stage = self.up[level] + for block_idx, block in enumerate(stage.block): + hidden_features = block(hidden_features, temb=None) + if stage.attn: + hidden_features = stage.attn[block_idx](hidden_features) + + if level != 0 and hasattr(stage, "upsample"): + hidden_features = stage.upsample(hidden_features) + + if self.give_pre_end: + return hidden_features + + hidden = self.norm_out(hidden_features) + hidden = self.non_linearity(hidden) + decoded_output = self.conv_out(hidden) + decoded_output = torch.tanh(decoded_output) if self.tanh_out else decoded_output + + _, _, current_time, current_freq = decoded_output.shape + target_time = target_frames + target_freq = target_mel_bins + + decoded_output = decoded_output[ + :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq) + ] + + time_padding_needed = target_time - decoded_output.shape[2] + freq_padding_needed = target_freq - decoded_output.shape[3] + + if time_padding_needed > 0 or freq_padding_needed > 0: + padding = ( + 0, + max(freq_padding_needed, 0), + 0, + max(time_padding_needed, 0), + ) + decoded_output = F.pad(decoded_output, padding) + + decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq] + + return decoded_output + + +class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin): + r""" + LTX2 audio VAE for encoding and decoding audio latent representations. + """ + + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + base_channels: int = 128, + output_channels: int = 2, + ch_mult: tuple[int, ...] = (1, 2, 4), + num_res_blocks: int = 2, + attn_resolutions: tuple[int, ...] | None = None, + in_channels: int = 2, + resolution: int = 256, + latent_channels: int = 8, + norm_type: str = "pixel", + causality_axis: str | None = "height", + dropout: float = 0.0, + mid_block_add_attention: bool = False, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: int | None = 64, + double_z: bool = True, + ) -> None: + super().__init__() + + supported_causality_axes = {"none", "width", "height", "width-compatibility"} + if causality_axis not in supported_causality_axes: + raise ValueError(f"{causality_axis=} is not valid. Supported values: {supported_causality_axes}") + + attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions + + self.encoder = LTX2AudioEncoder( + base_channels=base_channels, + output_channels=output_channels, + ch_mult=ch_mult, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolution_set, + in_channels=in_channels, + resolution=resolution, + latent_channels=latent_channels, + norm_type=norm_type, + causality_axis=causality_axis, + dropout=dropout, + mid_block_add_attention=mid_block_add_attention, + sample_rate=sample_rate, + mel_hop_length=mel_hop_length, + is_causal=is_causal, + mel_bins=mel_bins, + double_z=double_z, + ) + + self.decoder = LTX2AudioDecoder( + base_channels=base_channels, + output_channels=output_channels, + ch_mult=ch_mult, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolution_set, + in_channels=in_channels, + resolution=resolution, + latent_channels=latent_channels, + norm_type=norm_type, + causality_axis=causality_axis, + dropout=dropout, + mid_block_add_attention=mid_block_add_attention, + sample_rate=sample_rate, + mel_hop_length=mel_hop_length, + is_causal=is_causal, + mel_bins=mel_bins, + ) + + # Per-channel statistics for normalizing and denormalizing the latent representation. This statics is computed over + # the entire dataset and stored in model's checkpoint under AudioVAE state_dict + latents_std = torch.ones((base_channels,)) + latents_mean = torch.zeros((base_channels,)) + self.register_buffer("latents_mean", latents_mean, persistent=True) + self.register_buffer("latents_std", latents_std, persistent=True) + + # TODO: calculate programmatically instead of hardcoding + self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR # 4 + # TODO: confirm whether the mel compression ratio below is correct + self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR + self.use_slicing = False + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + return self.encoder(x) + + @apply_forward_hook + def encode(self, x: torch.Tensor, return_dict: bool = True): + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor) -> torch.Tensor: + return self.decoder(z) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z) + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: torch.Generator | None = None, + ) -> DecoderOutput | torch.Tensor: + posterior = self.encode(sample).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z) + if not return_dict: + return (dec.sample,) + return dec diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_mochi.py new file mode 100644 index 0000000000000000000000000000000000000000..a0f831c867b0129c5ec06b3f08206bac65a3bfbf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_mochi.py @@ -0,0 +1,1105 @@ +# Copyright 2025 The Mochi team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..attention_processor import Attention, MochiVaeAttnProcessor2_0 +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .autoencoder_kl_cogvideox import CogVideoXCausalConv3d +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class MochiChunkedGroupNorm3D(nn.Module): + r""" + Applies per-frame group normalization for 5D video inputs. It also supports memory-efficient chunked group + normalization. + + Args: + num_channels (int): Number of channels expected in input + num_groups (int, optional): Number of groups to separate the channels into. Default: 32 + affine (bool, optional): If True, this module has learnable affine parameters. Default: True + chunk_size (int, optional): Size of each chunk for processing. Default: 8 + + """ + + def __init__( + self, + num_channels: int, + num_groups: int = 32, + affine: bool = True, + chunk_size: int = 8, + ): + super().__init__() + self.norm_layer = nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, affine=affine) + self.chunk_size = chunk_size + + def forward(self, x: torch.Tensor = None) -> torch.Tensor: + batch_size = x.size(0) + + x = x.permute(0, 2, 1, 3, 4).flatten(0, 1) + output = torch.cat([self.norm_layer(chunk) for chunk in x.split(self.chunk_size, dim=0)], dim=0) + output = output.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + + return output + + +class MochiResnetBlock3D(nn.Module): + r""" + A 3D ResNet block used in the Mochi model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + non_linearity (`str`, defaults to `"swish"`): + Activation function to use. + """ + + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + act_fn: str = "swish", + ): + super().__init__() + + out_channels = out_channels or in_channels + + self.in_channels = in_channels + self.out_channels = out_channels + self.nonlinearity = get_activation(act_fn) + + self.norm1 = MochiChunkedGroupNorm3D(num_channels=in_channels) + self.conv1 = CogVideoXCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, pad_mode="replicate" + ) + self.norm2 = MochiChunkedGroupNorm3D(num_channels=out_channels) + self.conv2 = CogVideoXCausalConv3d( + in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, pad_mode="replicate" + ) + + def forward( + self, + inputs: torch.Tensor, + conv_cache: dict[str, torch.Tensor] | None = None, + ) -> torch.Tensor: + new_conv_cache = {} + conv_cache = conv_cache or {} + + hidden_states = inputs + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1")) + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2")) + + hidden_states = hidden_states + inputs + return hidden_states, new_conv_cache + + +class MochiDownBlock3D(nn.Module): + r""" + An downsampling block used in the Mochi model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet blocks in the block. + temporal_expansion (`int`, defaults to `2`): + Temporal expansion factor. + spatial_expansion (`int`, defaults to `2`): + Spatial expansion factor. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + temporal_expansion: int = 2, + spatial_expansion: int = 2, + add_attention: bool = True, + ): + super().__init__() + self.temporal_expansion = temporal_expansion + self.spatial_expansion = spatial_expansion + + self.conv_in = CogVideoXCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(temporal_expansion, spatial_expansion, spatial_expansion), + stride=(temporal_expansion, spatial_expansion, spatial_expansion), + pad_mode="replicate", + ) + + resnets = [] + norms = [] + attentions = [] + for _ in range(num_layers): + resnets.append(MochiResnetBlock3D(in_channels=out_channels)) + if add_attention: + norms.append(MochiChunkedGroupNorm3D(num_channels=out_channels)) + attentions.append( + Attention( + query_dim=out_channels, + heads=out_channels // 32, + dim_head=32, + qk_norm="l2", + is_causal=True, + processor=MochiVaeAttnProcessor2_0(), + ) + ) + else: + norms.append(None) + attentions.append(None) + + self.resnets = nn.ModuleList(resnets) + self.norms = nn.ModuleList(norms) + self.attentions = nn.ModuleList(attentions) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + conv_cache: dict[str, torch.Tensor] | None = None, + chunk_size: int = 2**15, + ) -> torch.Tensor: + r"""Forward method of the `MochiUpBlock3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + hidden_states, new_conv_cache["conv_in"] = self.conv_in(hidden_states) + + for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)): + conv_cache_key = f"resnet_{i}" + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + resnet, + hidden_states, + conv_cache.get(conv_cache_key), + ) + else: + hidden_states, new_conv_cache[conv_cache_key] = resnet( + hidden_states, conv_cache=conv_cache.get(conv_cache_key) + ) + + if attn is not None: + residual = hidden_states + hidden_states = norm(hidden_states) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).flatten(0, 2).contiguous() + + # Perform attention in chunks to avoid following error: + # RuntimeError: CUDA error: invalid configuration argument + if hidden_states.size(0) <= chunk_size: + hidden_states = attn(hidden_states) + else: + hidden_states_chunks = [] + for i in range(0, hidden_states.size(0), chunk_size): + hidden_states_chunk = hidden_states[i : i + chunk_size] + hidden_states_chunk = attn(hidden_states_chunk) + hidden_states_chunks.append(hidden_states_chunk) + hidden_states = torch.cat(hidden_states_chunks) + + hidden_states = hidden_states.unflatten(0, (batch_size, height, width)).permute(0, 4, 3, 1, 2) + + hidden_states = residual + hidden_states + + return hidden_states, new_conv_cache + + +class MochiMidBlock3D(nn.Module): + r""" + A middle block used in the Mochi model. + + Args: + in_channels (`int`): + Number of input channels. + num_layers (`int`, defaults to `3`): + Number of resnet blocks in the block. + """ + + def __init__( + self, + in_channels: int, # 768 + num_layers: int = 3, + add_attention: bool = True, + ): + super().__init__() + + resnets = [] + norms = [] + attentions = [] + + for _ in range(num_layers): + resnets.append(MochiResnetBlock3D(in_channels=in_channels)) + + if add_attention: + norms.append(MochiChunkedGroupNorm3D(num_channels=in_channels)) + attentions.append( + Attention( + query_dim=in_channels, + heads=in_channels // 32, + dim_head=32, + qk_norm="l2", + is_causal=True, + processor=MochiVaeAttnProcessor2_0(), + ) + ) + else: + norms.append(None) + attentions.append(None) + + self.resnets = nn.ModuleList(resnets) + self.norms = nn.ModuleList(norms) + self.attentions = nn.ModuleList(attentions) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + conv_cache: dict[str, torch.Tensor] | None = None, + ) -> torch.Tensor: + r"""Forward method of the `MochiMidBlock3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)): + conv_cache_key = f"resnet_{i}" + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + resnet, hidden_states, conv_cache.get(conv_cache_key) + ) + else: + hidden_states, new_conv_cache[conv_cache_key] = resnet( + hidden_states, conv_cache=conv_cache.get(conv_cache_key) + ) + + if attn is not None: + residual = hidden_states + hidden_states = norm(hidden_states) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).flatten(0, 2).contiguous() + hidden_states = attn(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, height, width)).permute(0, 4, 3, 1, 2) + + hidden_states = residual + hidden_states + + return hidden_states, new_conv_cache + + +class MochiUpBlock3D(nn.Module): + r""" + An upsampling block used in the Mochi model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet blocks in the block. + temporal_expansion (`int`, defaults to `2`): + Temporal expansion factor. + spatial_expansion (`int`, defaults to `2`): + Spatial expansion factor. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + temporal_expansion: int = 2, + spatial_expansion: int = 2, + ): + super().__init__() + self.temporal_expansion = temporal_expansion + self.spatial_expansion = spatial_expansion + + resnets = [] + for _ in range(num_layers): + resnets.append(MochiResnetBlock3D(in_channels=in_channels)) + self.resnets = nn.ModuleList(resnets) + + self.proj = nn.Linear(in_channels, out_channels * temporal_expansion * spatial_expansion**2) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + conv_cache: dict[str, torch.Tensor] | None = None, + ) -> torch.Tensor: + r"""Forward method of the `MochiUpBlock3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + for i, resnet in enumerate(self.resnets): + conv_cache_key = f"resnet_{i}" + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + resnet, + hidden_states, + conv_cache.get(conv_cache_key), + ) + else: + hidden_states, new_conv_cache[conv_cache_key] = resnet( + hidden_states, conv_cache=conv_cache.get(conv_cache_key) + ) + + hidden_states = hidden_states.permute(0, 2, 3, 4, 1) + hidden_states = self.proj(hidden_states) + hidden_states = hidden_states.permute(0, 4, 1, 2, 3) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + st = self.temporal_expansion + sh = self.spatial_expansion + sw = self.spatial_expansion + + # Reshape and unpatchify + hidden_states = hidden_states.view(batch_size, -1, st, sh, sw, num_frames, height, width) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + hidden_states = hidden_states.view(batch_size, -1, num_frames * st, height * sh, width * sw) + + return hidden_states, new_conv_cache + + +class FourierFeatures(nn.Module): + def __init__(self, start: int = 6, stop: int = 8, step: int = 1): + super().__init__() + + self.start = start + self.stop = stop + self.step = step + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + r"""Forward method of the `FourierFeatures` class.""" + original_dtype = inputs.dtype + inputs = inputs.to(torch.float32) + num_channels = inputs.shape[1] + num_freqs = (self.stop - self.start) // self.step + + freqs = torch.arange(self.start, self.stop, self.step, dtype=inputs.dtype, device=inputs.device) + w = torch.pow(2.0, freqs) * (2 * torch.pi) # [num_freqs] + w = w.repeat(num_channels)[None, :, None, None, None] # [1, num_channels * num_freqs, 1, 1, 1] + + # Interleaved repeat of input channels to match w + h = inputs.repeat_interleave( + num_freqs, dim=1, output_size=inputs.shape[1] * num_freqs + ) # [B, C * num_freqs, T, H, W] + # Scale channels by frequency. + h = w * h + + return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1).to(original_dtype) + + +class MochiEncoder3D(nn.Module): + r""" + The `MochiEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent + representation. + + Args: + in_channels (`int`, *optional*): + The number of input channels. + out_channels (`int`, *optional*): + The number of output channels. + block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`): + The number of output channels for each block. + layers_per_block (`tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`): + The number of resnet blocks for each block. + temporal_expansions (`tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`): + The temporal expansion factor for each of the up blocks. + spatial_expansions (`tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`): + The spatial expansion factor for each of the up blocks. + non_linearity (`str`, *optional*, defaults to `"swish"`): + The non-linearity to use in the decoder. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + block_out_channels: tuple[int, ...] = (128, 256, 512, 768), + layers_per_block: tuple[int, ...] = (3, 3, 4, 6, 3), + temporal_expansions: tuple[int, ...] = (1, 2, 3), + spatial_expansions: tuple[int, ...] = (2, 2, 2), + add_attention_block: tuple[bool, ...] = (False, True, True, True, True), + act_fn: str = "swish", + ): + super().__init__() + + self.nonlinearity = get_activation(act_fn) + + self.fourier_features = FourierFeatures() + self.proj_in = nn.Linear(in_channels, block_out_channels[0]) + self.block_in = MochiMidBlock3D( + in_channels=block_out_channels[0], num_layers=layers_per_block[0], add_attention=add_attention_block[0] + ) + + down_blocks = [] + for i in range(len(block_out_channels) - 1): + down_block = MochiDownBlock3D( + in_channels=block_out_channels[i], + out_channels=block_out_channels[i + 1], + num_layers=layers_per_block[i + 1], + temporal_expansion=temporal_expansions[i], + spatial_expansion=spatial_expansions[i], + add_attention=add_attention_block[i + 1], + ) + down_blocks.append(down_block) + self.down_blocks = nn.ModuleList(down_blocks) + + self.block_out = MochiMidBlock3D( + in_channels=block_out_channels[-1], num_layers=layers_per_block[-1], add_attention=add_attention_block[-1] + ) + self.norm_out = MochiChunkedGroupNorm3D(block_out_channels[-1]) + self.proj_out = nn.Linear(block_out_channels[-1], 2 * out_channels, bias=False) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor, conv_cache: dict[str, torch.Tensor] | None = None) -> torch.Tensor: + r"""Forward method of the `MochiEncoder3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + hidden_states = self.fourier_features(hidden_states) + + hidden_states = hidden_states.permute(0, 2, 3, 4, 1) + hidden_states = self.proj_in(hidden_states) + hidden_states = hidden_states.permute(0, 4, 1, 2, 3) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func( + self.block_in, hidden_states, conv_cache.get("block_in") + ) + + for i, down_block in enumerate(self.down_blocks): + conv_cache_key = f"down_block_{i}" + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + down_block, hidden_states, conv_cache.get(conv_cache_key) + ) + else: + hidden_states, new_conv_cache["block_in"] = self.block_in( + hidden_states, conv_cache=conv_cache.get("block_in") + ) + + for i, down_block in enumerate(self.down_blocks): + conv_cache_key = f"down_block_{i}" + hidden_states, new_conv_cache[conv_cache_key] = down_block( + hidden_states, conv_cache=conv_cache.get(conv_cache_key) + ) + + hidden_states, new_conv_cache["block_out"] = self.block_out( + hidden_states, conv_cache=conv_cache.get("block_out") + ) + + hidden_states = self.norm_out(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = hidden_states.permute(0, 2, 3, 4, 1) + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.permute(0, 4, 1, 2, 3) + + return hidden_states, new_conv_cache + + +class MochiDecoder3D(nn.Module): + r""" + The `MochiDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output + sample. + + Args: + in_channels (`int`, *optional*): + The number of input channels. + out_channels (`int`, *optional*): + The number of output channels. + block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`): + The number of output channels for each block. + layers_per_block (`tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`): + The number of resnet blocks for each block. + temporal_expansions (`tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`): + The temporal expansion factor for each of the up blocks. + spatial_expansions (`tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`): + The spatial expansion factor for each of the up blocks. + non_linearity (`str`, *optional*, defaults to `"swish"`): + The non-linearity to use in the decoder. + """ + + def __init__( + self, + in_channels: int, # 12 + out_channels: int, # 3 + block_out_channels: tuple[int, ...] = (128, 256, 512, 768), + layers_per_block: tuple[int, ...] = (3, 3, 4, 6, 3), + temporal_expansions: tuple[int, ...] = (1, 2, 3), + spatial_expansions: tuple[int, ...] = (2, 2, 2), + act_fn: str = "swish", + ): + super().__init__() + + self.nonlinearity = get_activation(act_fn) + + self.conv_in = nn.Conv3d(in_channels, block_out_channels[-1], kernel_size=(1, 1, 1)) + self.block_in = MochiMidBlock3D( + in_channels=block_out_channels[-1], + num_layers=layers_per_block[-1], + add_attention=False, + ) + + up_blocks = [] + for i in range(len(block_out_channels) - 1): + up_block = MochiUpBlock3D( + in_channels=block_out_channels[-i - 1], + out_channels=block_out_channels[-i - 2], + num_layers=layers_per_block[-i - 2], + temporal_expansion=temporal_expansions[-i - 1], + spatial_expansion=spatial_expansions[-i - 1], + ) + up_blocks.append(up_block) + self.up_blocks = nn.ModuleList(up_blocks) + + self.block_out = MochiMidBlock3D( + in_channels=block_out_channels[0], + num_layers=layers_per_block[0], + add_attention=False, + ) + self.proj_out = nn.Linear(block_out_channels[0], out_channels) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor, conv_cache: dict[str, torch.Tensor] | None = None) -> torch.Tensor: + r"""Forward method of the `MochiDecoder3D` class.""" + + new_conv_cache = {} + conv_cache = conv_cache or {} + + hidden_states = self.conv_in(hidden_states) + + # 1. Mid + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func( + self.block_in, hidden_states, conv_cache.get("block_in") + ) + + for i, up_block in enumerate(self.up_blocks): + conv_cache_key = f"up_block_{i}" + hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func( + up_block, hidden_states, conv_cache.get(conv_cache_key) + ) + else: + hidden_states, new_conv_cache["block_in"] = self.block_in( + hidden_states, conv_cache=conv_cache.get("block_in") + ) + + for i, up_block in enumerate(self.up_blocks): + conv_cache_key = f"up_block_{i}" + hidden_states, new_conv_cache[conv_cache_key] = up_block( + hidden_states, conv_cache=conv_cache.get(conv_cache_key) + ) + + hidden_states, new_conv_cache["block_out"] = self.block_out( + hidden_states, conv_cache=conv_cache.get("block_out") + ) + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = hidden_states.permute(0, 2, 3, 4, 1) + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.permute(0, 4, 1, 2, 3) + + return hidden_states, new_conv_cache + + +class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in + [Mochi 1 preview](https://github.com/genmoai/models). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`): + tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + scaling_factor (`float`, *optional*, defaults to `1.15258426`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["MochiResnetBlock3D"] + + @register_to_config + def __init__( + self, + in_channels: int = 15, + out_channels: int = 3, + encoder_block_out_channels: tuple[int] = (64, 128, 256, 384), + decoder_block_out_channels: tuple[int] = (128, 256, 512, 768), + latent_channels: int = 12, + layers_per_block: tuple[int, ...] = (3, 3, 4, 6, 3), + act_fn: str = "silu", + temporal_expansions: tuple[int, ...] = (1, 2, 3), + spatial_expansions: tuple[int, ...] = (2, 2, 2), + add_attention_block: tuple[bool, ...] = (False, True, True, True, True), + latents_mean: tuple[float, ...] = ( + -0.06730895953510081, + -0.038011381506090416, + -0.07477820912866141, + -0.05565264470995561, + 0.012767231469026969, + -0.04703542746246419, + 0.043896967884726704, + -0.09346305707025976, + -0.09918314763016893, + -0.008729793427399178, + -0.011931556316503654, + -0.0321993391887285, + ), + latents_std: tuple[float, ...] = ( + 0.9263795028493863, + 0.9248894543193766, + 0.9393059390890617, + 0.959253732819592, + 0.8244560132752793, + 0.917259975397747, + 0.9294154431013696, + 1.3720942357788521, + 0.881393668867029, + 0.9168315692124348, + 0.9185249279345552, + 0.9274757570805041, + ), + scaling_factor: float = 1.0, + ): + super().__init__() + + self.encoder = MochiEncoder3D( + in_channels=in_channels, + out_channels=latent_channels, + block_out_channels=encoder_block_out_channels, + layers_per_block=layers_per_block, + temporal_expansions=temporal_expansions, + spatial_expansions=spatial_expansions, + add_attention_block=add_attention_block, + act_fn=act_fn, + ) + self.decoder = MochiDecoder3D( + in_channels=latent_channels, + out_channels=out_channels, + block_out_channels=decoder_block_out_channels, + layers_per_block=layers_per_block, + temporal_expansions=temporal_expansions, + spatial_expansions=spatial_expansions, + act_fn=act_fn, + ) + + self.spatial_compression_ratio = functools.reduce(lambda x, y: x * y, spatial_expansions, 1) + self.temporal_compression_ratio = functools.reduce(lambda x, y: x * y, temporal_expansions, 1) + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames + # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered. + self.use_framewise_encoding = False + self.use_framewise_decoding = False + + # This can be used to determine how the number of output frames in the final decoded video. To maintain consistency with + # the original implementation, this defaults to `True`. + # - Original implementation (drop_last_temporal_frames=True): + # Output frames = (latent_frames - 1) * temporal_compression_ratio + 1 + # - Without dropping additional temporal upscaled frames (drop_last_temporal_frames=False): + # Output frames = latent_frames * temporal_compression_ratio + # The latter case is useful for frame packing and some training/finetuning scenarios where the additional. + self.drop_last_temporal_frames = True + + # This can be configured based on the amount of GPU memory available. + # `12` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs. + # Setting it to higher values results in higher memory usage. + self.num_sample_frames_batch_size = 12 + self.num_latent_frames_batch_size = 2 + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + + def enable_tiling( + self, + tile_sample_min_height: int | None = None, + tile_sample_min_width: int | None = None, + tile_sample_stride_height: float | None = None, + tile_sample_stride_width: float | None = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + + def _enable_framewise_encoding(self): + r""" + Enables the framewise VAE encoding implementation with past latent padding. By default, Diffusers uses the + oneshot encoding implementation without current latent replicate padding. + + Warning: Framewise encoding may not work as expected due to the causal attention layers. If you enable + framewise encoding, encode a video, and try to decode it, there will be noticeable jittering effect. + """ + self.use_framewise_encoding = True + for name, module in self.named_modules(): + if isinstance(module, CogVideoXCausalConv3d): + module.pad_mode = "constant" + + def _enable_framewise_decoding(self): + r""" + Enables the framewise VAE decoding implementation with past latent padding. By default, Diffusers uses the + oneshot decoding implementation without current latent replicate padding. + """ + self.use_framewise_decoding = True + for name, module in self.named_modules(): + if isinstance(module, CogVideoXCausalConv3d): + module.pad_mode = "constant" + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + if self.use_framewise_encoding: + raise NotImplementedError( + "Frame-wise encoding does not work with the Mochi VAE Encoder due to the presence of attention layers. " + "As intermediate frames are not independent from each other, they cannot be encoded frame-wise." + ) + else: + enc, _ = self.encoder(x) + + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + batch_size, num_channels, num_frames, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + if self.use_framewise_decoding: + conv_cache = None + dec = [] + + for i in range(0, num_frames, self.num_latent_frames_batch_size): + z_intermediate = z[:, :, i : i + self.num_latent_frames_batch_size] + z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache) + dec.append(z_intermediate) + + dec = torch.cat(dec, dim=2) + else: + dec, _ = self.decoder(z) + + if self.drop_last_temporal_frames and dec.size(2) >= self.temporal_compression_ratio: + dec = dec[:, :, self.temporal_compression_ratio - 1 :] + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + batch_size, num_channels, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + if self.use_framewise_encoding: + raise NotImplementedError( + "Frame-wise encoding does not work with the Mochi VAE Encoder due to the presence of attention layers. " + "As intermediate frames are not independent from each other, they cannot be encoded frame-wise." + ) + else: + time, _ = self.encoder( + x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + ) + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + + batch_size, num_channels, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + if self.use_framewise_decoding: + time = [] + conv_cache = None + + for k in range(0, num_frames, self.num_latent_frames_batch_size): + tile = z[ + :, + :, + k : k + self.num_latent_frames_batch_size, + i : i + tile_latent_min_height, + j : j + tile_latent_min_width, + ] + tile, conv_cache = self.decoder(tile, conv_cache=conv_cache) + time.append(tile) + + time = torch.cat(time, dim=2) + else: + time, _ = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]) + + if self.drop_last_temporal_frames and time.size(2) >= self.temporal_compression_ratio: + time = time[:, :, self.temporal_compression_ratio - 1 :] + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: torch.Generator | None = None, + ) -> torch.Tensor | torch.Tensor: + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z) + if not return_dict: + return (dec,) + return dec diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py new file mode 100644 index 0000000000000000000000000000000000000000..f2ca0f42a2729c4ec709060c47015179c8397972 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py @@ -0,0 +1,1049 @@ +# Copyright 2025 The Qwen-Image Team, Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# We gratefully acknowledge the Wan Team for their outstanding contributions. +# QwenImageVAE is further fine-tuned from the Wan Video VAE to achieve improved performance. +# For more information about the Wan VAE, please refer to: +# - GitHub: https://github.com/Wan-Video/Wan2.1 +# - Paper: https://huggingface.co/papers/2503.20314 + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin +from ...utils import logging +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +CACHE_T = 2 + + +class QwenImageCausalConv3d(nn.Conv3d): + r""" + A custom 3D causal convolution layer with feature caching support. + + This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature + caching for efficient inference. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0 + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int, int], + stride: int | tuple[int, int, int] = 1, + padding: int | tuple[int, int, int] = 0, + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + # Set up causal padding + self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + return super().forward(x) + + +class QwenImageRMS_norm(nn.Module): + r""" + A custom RMS normalization layer. + + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. + """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class QwenImageUpsample(nn.Upsample): + r""" + Perform upsampling while ensuring the output tensor has the same data type as the input. + + Args: + x (torch.Tensor): Input tensor to be upsampled. + + Returns: + torch.Tensor: Upsampled tensor with the same data type as the input. + """ + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class QwenImageResample(nn.Module): + r""" + A custom resampling module for 2D and 3D data. + + Args: + dim (int): The number of input/output channels. + mode (str): The resampling mode. Must be one of: + - 'none': No resampling (identity operation). + - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. + - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. + - 'downsample2d': 2D downsampling with zero-padding and convolution. + - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. + """ + + def __init__(self, dim: int, mode: str) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + +class QwenImageResidualBlock(nn.Module): + r""" + A custom residual block module. + + Args: + in_dim (int): Number of input channels. + out_dim (int): Number of output channels. + dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0. + non_linearity (str, optional): Type of non-linearity to use. Default is "silu". + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + dropout: float = 0.0, + non_linearity: str = "silu", + ) -> None: + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.nonlinearity = get_activation(non_linearity) + + # layers + self.norm1 = QwenImageRMS_norm(in_dim, images=False) + self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1) + self.norm2 = QwenImageRMS_norm(out_dim, images=False) + self.dropout = nn.Dropout(dropout) + self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1) + self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # Apply shortcut connection + h = self.conv_shortcut(x) + + # First normalization and activation + x = self.norm1(x) + x = self.nonlinearity(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + # Second normalization and activation + x = self.norm2(x) + x = self.nonlinearity(x) + + # Dropout + x = self.dropout(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv2(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv2(x) + + # Add residual connection + return x + h + + +class QwenImageAttentionBlock(nn.Module): + r""" + Causal self-attention with a single head. + + Args: + dim (int): The number of channels in the input tensor. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = QwenImageRMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + def forward(self, x): + identity = x + batch_size, channels, time, height, width = x.size() + + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width) + x = self.norm(x) + + # compute query, key, value + qkv = self.to_qkv(x) + qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + qkv = qkv.permute(0, 1, 3, 2).contiguous() + q, k, v = qkv.chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention(q, k, v) + + x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width) + + # output projection + x = self.proj(x) + + # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w] + x = x.view(batch_size, time, channels, height, width) + x = x.permute(0, 2, 1, 3, 4) + + return x + identity + + +class QwenImageMidBlock(nn.Module): + """ + Middle block for QwenImageVAE encoder and decoder. + + Args: + dim (int): Number of input/output channels. + dropout (float): Dropout rate. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1): + super().__init__() + self.dim = dim + + # Create the components + resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)] + attentions = [] + for _ in range(num_layers): + attentions.append(QwenImageAttentionBlock(dim)) + resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity)) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # First residual block + x = self.resnets[0](x, feat_cache, feat_idx) + + # Process through attention and residual blocks + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + x = attn(x) + + x = resnet(x, feat_cache, feat_idx) + + return x + + +class QwenImageEncoder3d(nn.Module): + r""" + A 3D encoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_downsample (list of bool): Whether to downsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + input_channels=3, + non_linearity: str = "silu", + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv_in = QwenImageCausalConv3d(input_channels, dims[0], 3, padding=1) + + # downsample blocks + self.down_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + self.down_blocks.append(QwenImageAttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + self.down_blocks.append(QwenImageResample(out_dim, mode=mode)) + scale /= 2.0 + + # middle blocks + self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1) + + # output blocks + self.norm_out = QwenImageRMS_norm(out_dim, images=False) + self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## downsamples + for layer in self.down_blocks: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +class QwenImageUpBlock(nn.Module): + """ + A block that handles upsampling for the QwenImageVAE decoder. + + Args: + in_dim (int): Input dimension + out_dim (int): Output dimension + num_res_blocks (int): Number of residual blocks + dropout (float): Dropout rate + upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d') + non_linearity (str): Type of non-linearity to use + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + upsample_mode: str | None = None, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # Create layers list + resnets = [] + # Add residual blocks and attention if needed + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity)) + current_dim = out_dim + + self.resnets = nn.ModuleList(resnets) + + # Add upsampling layer if needed + self.upsamplers = None + if upsample_mode is not None: + self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)]) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + """ + Forward pass through the upsampling block. + + Args: + x (torch.Tensor): Input tensor + feat_cache (list, optional): Feature cache for causal convolutions + feat_idx (list, optional): Feature index for cache management + + Returns: + torch.Tensor: Output tensor + """ + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache, feat_idx) + else: + x = resnet(x) + + if self.upsamplers is not None: + if feat_cache is not None: + x = self.upsamplers[0](x, feat_cache, feat_idx) + else: + x = self.upsamplers[0](x) + return x + + +class QwenImageDecoder3d(nn.Module): + r""" + A 3D decoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_upsample (list of bool): Whether to upsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + input_channels=3, + non_linearity: str = "silu", + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + + # init block + self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1) + + # upsample blocks + self.up_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i > 0: + in_dim = in_dim // 2 + + # Determine if we need upsampling + upsample_mode = None + if i != len(dim_mult) - 1: + upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d" + + # Create and add the upsampling block + up_block = QwenImageUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + upsample_mode=upsample_mode, + non_linearity=non_linearity, + ) + self.up_blocks.append(up_block) + + # Update scale for next iteration + if upsample_mode is not None: + scale *= 2.0 + + # output blocks + self.norm_out = QwenImageRMS_norm(out_dim, images=False) + self.conv_out = QwenImageCausalConv3d(out_dim, input_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## upsamples + for up_block in self.up_blocks: + x = up_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + """ + + _supports_gradient_checkpointing = False + + # fmt: off + @register_to_config + def __init__( + self, + base_dim: int = 96, + z_dim: int = 16, + dim_mult: list[int] = [1, 2, 4, 4], + num_res_blocks: int = 2, + attn_scales: list[float] = [], + temperal_downsample: list[bool] = [False, True, True], + dropout: float = 0.0, + input_channels: int = 3, + latents_mean: list[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921], + latents_std: list[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160], + ) -> None: + # fmt: on + super().__init__() + + self.z_dim = z_dim + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + self.encoder = QwenImageEncoder3d( + base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, input_channels + ) + self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1) + self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1) + + self.decoder = QwenImageDecoder3d( + base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, input_channels + ) + + self.spatial_compression_ratio = 2 ** len(self.temperal_downsample) + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + + # Precompute and cache conv counts for encoder and decoder for clear_cache speedup + self._cached_conv_counts = { + "decoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.decoder.modules()) + if self.decoder is not None + else 0, + "encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.modules()) + if self.encoder is not None + else 0, + } + + def enable_tiling( + self, + tile_sample_min_height: int | None = None, + tile_sample_min_width: int | None = None, + tile_sample_stride_height: float | None = None, + tile_sample_stride_width: float | None = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + + def clear_cache(self): + def _count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, QwenImageCausalConv3d): + count += 1 + return count + + self._conv_num = _count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = _count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + def _encode(self, x: torch.Tensor): + _, _, num_frame, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + self.clear_cache() + iter_ = 1 + (num_frame - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + out = torch.cat([out, out_], 2) + + enc = self.quant_conv(out) + self.clear_cache() + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True): + _, _, num_frame, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + self.clear_cache() + x = self.post_quant_conv(z) + for i in range(num_frame): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + + out = torch.clamp(out, min=-1.0, max=1.0) + self.clear_cache() + if not return_dict: + return (out,) + + return DecoderOutput(sample=out) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + _, _, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + self.clear_cache() + time = [] + frame_range = 1 + (num_frames - 1) // 4 + for k in range(frame_range): + self._enc_conv_idx = [0] + if k == 0: + tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + else: + tile = x[ + :, + :, + 1 + 4 * (k - 1) : 1 + 4 * k, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + tile = self.quant_conv(tile) + time.append(tile) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + _, _, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + self.clear_cache() + time = [] + for k in range(num_frames): + self._conv_idx = [0] + tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx) + time.append(decoded) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: torch.Generator | None = None, + ) -> DecoderOutput | torch.Tensor: + """ + Args: + sample (`torch.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..95d4b0b7b535b6e46ef39fbf038f8ed8dc874033 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py @@ -0,0 +1,303 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import itertools + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils.accelerate_utils import apply_forward_hook +from ..attention import AttentionMixin +from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttnProcessor +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder + + +class TemporalDecoder(nn.Module): + def __init__( + self, + in_channels: int = 4, + out_channels: int = 3, + block_out_channels: tuple[int] = (128, 256, 512, 512), + layers_per_block: int = 2, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1) + self.mid_block = MidBlockTemporalDecoder( + num_layers=self.layers_per_block, + in_channels=block_out_channels[-1], + out_channels=block_out_channels[-1], + attention_head_dim=block_out_channels[-1], + ) + + # up + self.up_blocks = nn.ModuleList([]) + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + up_block = UpBlockTemporalDecoder( + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=not is_final_block, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-6) + + self.conv_act = nn.SiLU() + self.conv_out = torch.nn.Conv2d( + in_channels=block_out_channels[0], + out_channels=out_channels, + kernel_size=3, + padding=1, + ) + + conv_out_kernel_size = (3, 1, 1) + padding = [int(k // 2) for k in conv_out_kernel_size] + self.time_conv_out = torch.nn.Conv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=conv_out_kernel_size, + padding=padding, + ) + + self.gradient_checkpointing = False + + def forward( + self, + sample: torch.Tensor, + image_only_indicator: torch.Tensor, + num_frames: int = 1, + ) -> torch.Tensor: + r"""The forward method of the `Decoder` class.""" + + sample = self.conv_in(sample) + + upscale_dtype = next(itertools.chain(self.up_blocks.parameters(), self.up_blocks.buffers())).dtype + if torch.is_grad_enabled() and self.gradient_checkpointing: + # middle + sample = self._gradient_checkpointing_func( + self.mid_block, + sample, + image_only_indicator, + ) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = self._gradient_checkpointing_func( + up_block, + sample, + image_only_indicator, + ) + else: + # middle + sample = self.mid_block(sample, image_only_indicator=image_only_indicator) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = up_block(sample, image_only_indicator=image_only_indicator) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + batch_frames, channels, height, width = sample.shape + batch_size = batch_frames // num_frames + sample = sample[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4) + sample = self.time_conv_out(sample) + + sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width) + + return sample + + +class AutoencoderKLTemporalDecoder(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + tuple of downsample block types. + block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`): + tuple of block output channels. + layers_per_block: (`int`, *optional*, defaults to 1): Number of layers per block. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast` + can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: tuple[str] = ("DownEncoderBlock2D",), + block_out_channels: tuple[int] = (64,), + layers_per_block: int = 1, + latent_channels: int = 4, + sample_size: int = 32, + scaling_factor: float = 0.18215, + force_upcast: float = True, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = TemporalDecoder( + in_channels=latent_channels, + out_channels=out_channels, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] instead of a plain + tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + @apply_forward_hook + def decode( + self, + z: torch.Tensor, + num_frames: int, + return_dict: bool = True, + ) -> DecoderOutput | torch.Tensor: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + + """ + batch_size = z.shape[0] // num_frames + image_only_indicator = torch.zeros(batch_size, num_frames, dtype=z.dtype, device=z.device) + decoded = self.decoder(z, num_frames=num_frames, image_only_indicator=image_only_indicator) + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: torch.Generator | None = None, + num_frames: int = 1, + ) -> DecoderOutput | torch.Tensor: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + + dec = self.decode(z, num_frames=num_frames).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py new file mode 100644 index 0000000000000000000000000000000000000000..ea5d2efe642f78d4ea39adf2b3ddde7c135f4172 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -0,0 +1,1423 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin +from ...utils import logging +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +CACHE_T = 2 + + +class AvgDown3D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert in_channels * self.factor % out_channels == 0 + self.group_size = in_channels * self.factor // out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t + pad = (0, 0, 0, 0, pad_t, 0) + x = F.pad(x, pad) + B, C, T, H, W = x.shape + x = x.view( + B, + C, + T // self.factor_t, + self.factor_t, + H // self.factor_s, + self.factor_s, + W // self.factor_s, + self.factor_s, + ) + x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() + x = x.view( + B, + C * self.factor, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.view( + B, + self.out_channels, + self.group_size, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.mean(dim=2) + return x + + +class DupUp3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert out_channels * self.factor % in_channels == 0 + self.repeats = out_channels * self.factor // in_channels + + def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor: + x = x.repeat_interleave(self.repeats, dim=1) + x = x.view( + x.size(0), + self.out_channels, + self.factor_t, + self.factor_s, + self.factor_s, + x.size(2), + x.size(3), + x.size(4), + ) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + x = x.view( + x.size(0), + self.out_channels, + x.size(2) * self.factor_t, + x.size(4) * self.factor_s, + x.size(6) * self.factor_s, + ) + if first_chunk: + x = x[:, :, self.factor_t - 1 :, :, :] + return x + + +class WanCausalConv3d(nn.Conv3d): + r""" + A custom 3D causal convolution layer with feature caching support. + + This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature + caching for efficient inference. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0 + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int, int], + stride: int | tuple[int, int, int] = 1, + padding: int | tuple[int, int, int] = 0, + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + # Set up causal padding + self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + return super().forward(x) + + +class WanRMS_norm(nn.Module): + r""" + A custom RMS normalization layer. + + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. + """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class WanUpsample(nn.Upsample): + r""" + Perform upsampling while ensuring the output tensor has the same data type as the input. + + Args: + x (torch.Tensor): Input tensor to be upsampled. + + Returns: + torch.Tensor: Upsampled tensor with the same data type as the input. + """ + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class WanResample(nn.Module): + r""" + A custom resampling module for 2D and 3D data. + + Args: + dim (int): The number of input/output channels. + mode (str): The resampling mode. Must be one of: + - 'none': No resampling (identity operation). + - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. + - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. + - 'downsample2d': 2D downsampling with zero-padding and convolution. + - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. + """ + + def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # default to dim //2 + if upsample_out_dim is None: + upsample_out_dim = dim // 2 + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, upsample_out_dim, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, upsample_out_dim, 3, padding=1), + ) + self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + +class WanResidualBlock(nn.Module): + r""" + A custom residual block module. + + Args: + in_dim (int): Number of input channels. + out_dim (int): Number of output channels. + dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0. + non_linearity (str, optional): Type of non-linearity to use. Default is "silu". + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + dropout: float = 0.0, + non_linearity: str = "silu", + ) -> None: + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.nonlinearity = get_activation(non_linearity) + + # layers + self.norm1 = WanRMS_norm(in_dim, images=False) + self.conv1 = WanCausalConv3d(in_dim, out_dim, 3, padding=1) + self.norm2 = WanRMS_norm(out_dim, images=False) + self.dropout = nn.Dropout(dropout) + self.conv2 = WanCausalConv3d(out_dim, out_dim, 3, padding=1) + self.conv_shortcut = WanCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # Apply shortcut connection + h = self.conv_shortcut(x) + + # First normalization and activation + x = self.norm1(x) + x = self.nonlinearity(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + # Second normalization and activation + x = self.norm2(x) + x = self.nonlinearity(x) + + # Dropout + x = self.dropout(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv2(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv2(x) + + # Add residual connection + return x + h + + +class WanAttentionBlock(nn.Module): + r""" + Causal self-attention with a single head. + + Args: + dim (int): The number of channels in the input tensor. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = WanRMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + def forward(self, x): + identity = x + batch_size, channels, time, height, width = x.size() + + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width) + x = self.norm(x) + + # compute query, key, value + qkv = self.to_qkv(x) + qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + qkv = qkv.permute(0, 1, 3, 2).contiguous() + q, k, v = qkv.chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention(q, k, v) + + x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width) + + # output projection + x = self.proj(x) + + # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w] + x = x.view(batch_size, time, channels, height, width) + x = x.permute(0, 2, 1, 3, 4) + + return x + identity + + +class WanMidBlock(nn.Module): + """ + Middle block for WanVAE encoder and decoder. + + Args: + dim (int): Number of input/output channels. + dropout (float): Dropout rate. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1): + super().__init__() + self.dim = dim + + # Create the components + resnets = [WanResidualBlock(dim, dim, dropout, non_linearity)] + attentions = [] + for _ in range(num_layers): + attentions.append(WanAttentionBlock(dim)) + resnets.append(WanResidualBlock(dim, dim, dropout, non_linearity)) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # First residual block + x = self.resnets[0](x, feat_cache=feat_cache, feat_idx=feat_idx) + + # Process through attention and residual blocks + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + x = attn(x) + + x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx) + + return x + + +class WanResidualDownBlock(nn.Module): + def __init__(self, in_dim, out_dim, dropout, num_res_blocks, temperal_downsample=False, down_flag=False): + super().__init__() + + # Shortcut path with downsample + self.avg_shortcut = AvgDown3D( + in_dim, + out_dim, + factor_t=2 if temperal_downsample else 1, + factor_s=2 if down_flag else 1, + ) + + # Main path with residual blocks and downsample + resnets = [] + for _ in range(num_res_blocks): + resnets.append(WanResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + self.resnets = nn.ModuleList(resnets) + + # Add the final downsample block + if down_flag: + mode = "downsample3d" if temperal_downsample else "downsample2d" + self.downsampler = WanResample(out_dim, mode=mode) + else: + self.downsampler = None + + def forward(self, x, feat_cache=None, feat_idx=[0]): + x_copy = x.clone() + for resnet in self.resnets: + x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx) + if self.downsampler is not None: + x = self.downsampler(x, feat_cache=feat_cache, feat_idx=feat_idx) + + return x + self.avg_shortcut(x_copy) + + +class WanEncoder3d(nn.Module): + r""" + A 3D encoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_downsample (list of bool): Whether to downsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__( + self, + in_channels: int = 3, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + non_linearity: str = "silu", + is_residual: bool = False, # wan 2.2 vae use a residual downblock + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1) + + # downsample blocks + self.down_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if is_residual: + self.down_blocks.append( + WanResidualDownBlock( + in_dim, + out_dim, + dropout, + num_res_blocks, + temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False, + down_flag=i != len(dim_mult) - 1, + ) + ) + else: + for _ in range(num_res_blocks): + self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + self.down_blocks.append(WanAttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + self.down_blocks.append(WanResample(out_dim, mode=mode)) + scale /= 2.0 + + # middle blocks + self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1) + + # output blocks + self.norm_out = WanRMS_norm(out_dim, images=False) + self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## downsamples + for layer in self.down_blocks: + if feat_cache is not None: + x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx) + else: + x = layer(x) + + ## middle + x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + + return x + + +class WanResidualUpBlock(nn.Module): + """ + A block that handles upsampling for the WanVAE decoder. + + Args: + in_dim (int): Input dimension + out_dim (int): Output dimension + num_res_blocks (int): Number of residual blocks + dropout (float): Dropout rate + temperal_upsample (bool): Whether to upsample on temporal dimension + up_flag (bool): Whether to upsample or not + non_linearity (str): Type of non-linearity to use + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + temperal_upsample: bool = False, + up_flag: bool = False, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + if up_flag: + self.avg_shortcut = DupUp3D( + in_dim, + out_dim, + factor_t=2 if temperal_upsample else 1, + factor_s=2, + ) + else: + self.avg_shortcut = None + + # create residual blocks + resnets = [] + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity)) + current_dim = out_dim + + self.resnets = nn.ModuleList(resnets) + + # Add upsampling layer if needed + if up_flag: + upsample_mode = "upsample3d" if temperal_upsample else "upsample2d" + self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim) + else: + self.upsampler = None + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + """ + Forward pass through the upsampling block. + + Args: + x (torch.Tensor): Input tensor + feat_cache (list, optional): Feature cache for causal convolutions + feat_idx (list, optional): Feature index for cache management + + Returns: + torch.Tensor: Output tensor + """ + x_copy = x.clone() + + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx) + else: + x = resnet(x) + + if self.upsampler is not None: + if feat_cache is not None: + x = self.upsampler(x, feat_cache=feat_cache, feat_idx=feat_idx) + else: + x = self.upsampler(x) + + if self.avg_shortcut is not None: + x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk) + + return x + + +class WanUpBlock(nn.Module): + """ + A block that handles upsampling for the WanVAE decoder. + + Args: + in_dim (int): Input dimension + out_dim (int): Output dimension + num_res_blocks (int): Number of residual blocks + dropout (float): Dropout rate + upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d') + non_linearity (str): Type of non-linearity to use + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + upsample_mode: str | None = None, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # Create layers list + resnets = [] + # Add residual blocks and attention if needed + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity)) + current_dim = out_dim + + self.resnets = nn.ModuleList(resnets) + + # Add upsampling layer if needed + self.upsamplers = None + if upsample_mode is not None: + self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)]) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None): + """ + Forward pass through the upsampling block. + + Args: + x (torch.Tensor): Input tensor + feat_cache (list, optional): Feature cache for causal convolutions + feat_idx (list, optional): Feature index for cache management + + Returns: + torch.Tensor: Output tensor + """ + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx) + else: + x = resnet(x) + + if self.upsamplers is not None: + if feat_cache is not None: + x = self.upsamplers[0](x, feat_cache=feat_cache, feat_idx=feat_idx) + else: + x = self.upsamplers[0](x) + return x + + +class WanDecoder3d(nn.Module): + r""" + A 3D decoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_upsample (list of bool): Whether to upsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + non_linearity: str = "silu", + out_channels: int = 3, + is_residual: bool = False, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + + # init block + self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1) + + # upsample blocks + self.up_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i > 0 and not is_residual: + # wan vae 2.1 + in_dim = in_dim // 2 + + # determine if we need upsampling + up_flag = i != len(dim_mult) - 1 + # determine upsampling mode, if not upsampling, set to None + upsample_mode = None + if up_flag and temperal_upsample[i]: + upsample_mode = "upsample3d" + elif up_flag: + upsample_mode = "upsample2d" + # Create and add the upsampling block + if is_residual: + up_block = WanResidualUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + temperal_upsample=temperal_upsample[i] if up_flag else False, + up_flag=up_flag, + non_linearity=non_linearity, + ) + else: + up_block = WanUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + upsample_mode=upsample_mode, + non_linearity=non_linearity, + ) + self.up_blocks.append(up_block) + + # output blocks + self.norm_out = WanRMS_norm(out_dim, images=False) + self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## middle + x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx) + + ## upsamples + for up_block in self.up_blocks: + x = up_block(x, feat_cache=feat_cache, feat_idx=feat_idx, first_chunk=first_chunk) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +def patchify(x, patch_size): + if patch_size == 1: + return x + + if x.dim() != 5: + raise ValueError(f"Invalid input shape: {x.shape}") + # x shape: [batch_size, channels, frames, height, width] + batch_size, channels, frames, height, width = x.shape + + # Ensure height and width are divisible by patch_size + if height % patch_size != 0 or width % patch_size != 0: + raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})") + + # Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size] + x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size) + + # Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size] + x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous() + x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size) + + return x + + +def unpatchify(x, patch_size): + if patch_size == 1: + return x + + if x.dim() != 5: + raise ValueError(f"Invalid input shape: {x.shape}") + # x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width] + batch_size, c_patches, frames, height, width = x.shape + channels = c_patches // (patch_size * patch_size) + + # Reshape to [b, c, patch_size, patch_size, f, h, w] + x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width) + + # Rearrange to [b, c, f, h * patch_size, w * patch_size] + x = x.permute(0, 1, 4, 5, 3, 6, 2).contiguous() + x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size) + + return x + + +class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. + Introduced in [Wan 2.1]. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + """ + + _supports_gradient_checkpointing = False + _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"] + # keys toignore when AlignDeviceHook moves inputs/outputs between devices + # these are shared mutable state modified in-place + _skip_keys = ["feat_cache", "feat_idx"] + + @register_to_config + def __init__( + self, + base_dim: int = 96, + decoder_base_dim: int | None = None, + z_dim: int = 16, + dim_mult: list[int] = [1, 2, 4, 4], + num_res_blocks: int = 2, + attn_scales: list[float] = [], + temperal_downsample: list[bool] = [False, True, True], + dropout: float = 0.0, + latents_mean: list[float] = [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ], + latents_std: list[float] = [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ], + is_residual: bool = False, + in_channels: int = 3, + out_channels: int = 3, + patch_size: int | None = None, + scale_factor_temporal: int | None = 4, + scale_factor_spatial: int | None = 8, + ) -> None: + super().__init__() + + self.z_dim = z_dim + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + if decoder_base_dim is None: + decoder_base_dim = base_dim + + self.encoder = WanEncoder3d( + in_channels=in_channels, + dim=base_dim, + z_dim=z_dim * 2, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, + dropout=dropout, + is_residual=is_residual, + ) + self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1) + self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1) + + self.decoder = WanDecoder3d( + dim=decoder_base_dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_upsample=self.temperal_upsample, + dropout=dropout, + out_channels=out_channels, + is_residual=is_residual, + ) + + self.spatial_compression_ratio = scale_factor_spatial + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + + # Precompute and cache conv counts for encoder and decoder for clear_cache speedup + self._cached_conv_counts = { + "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules()) + if self.decoder is not None + else 0, + "encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules()) + if self.encoder is not None + else 0, + } + + def enable_tiling( + self, + tile_sample_min_height: int | None = None, + tile_sample_min_width: int | None = None, + tile_sample_stride_height: float | None = None, + tile_sample_stride_width: float | None = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + + def clear_cache(self): + # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call + self._conv_num = self._cached_conv_counts["decoder"] + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = self._cached_conv_counts["encoder"] + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + def _encode(self, x: torch.Tensor): + _, _, num_frame, height, width = x.shape + + self.clear_cache() + if self.config.patch_size is not None: + x = patchify(x, patch_size=self.config.patch_size) + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + iter_ = 1 + (num_frame - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + out = torch.cat([out, out_], 2) + + enc = self.quant_conv(out) + self.clear_cache() + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True): + _, _, num_frame, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + self.clear_cache() + x = self.post_quant_conv(z) + for i in range(num_frame): + self._conv_idx = [0] + if i == 0: + out = self.decoder( + x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True + ) + else: + out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + + if self.config.patch_size is not None: + out = unpatchify(out, patch_size=self.config.patch_size) + + out = torch.clamp(out, min=-1.0, max=1.0) + + self.clear_cache() + if not return_dict: + return (out,) + + return DecoderOutput(sample=out) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + + _, _, num_frames, height, width = x.shape + encode_spatial_compression_ratio = self.spatial_compression_ratio + if self.config.patch_size is not None: + assert encode_spatial_compression_ratio % self.config.patch_size == 0 + encode_spatial_compression_ratio = self.spatial_compression_ratio // self.config.patch_size + + latent_height = height // encode_spatial_compression_ratio + latent_width = width // encode_spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // encode_spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // encode_spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // encode_spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // encode_spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + self.clear_cache() + time = [] + frame_range = 1 + (num_frames - 1) // 4 + for k in range(frame_range): + self._enc_conv_idx = [0] + if k == 0: + tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + else: + tile = x[ + :, + :, + 1 + 4 * (k - 1) : 1 + 4 * k, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + tile = self.quant_conv(tile) + time.append(tile) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + _, _, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + tile_sample_stride_height = self.tile_sample_stride_height + tile_sample_stride_width = self.tile_sample_stride_width + if self.config.patch_size is not None: + sample_height = sample_height // self.config.patch_size + sample_width = sample_width // self.config.patch_size + tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size + tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size + blend_height = self.tile_sample_min_height // self.config.patch_size - tile_sample_stride_height + blend_width = self.tile_sample_min_width // self.config.patch_size - tile_sample_stride_width + else: + blend_height = self.tile_sample_min_height - tile_sample_stride_height + blend_width = self.tile_sample_min_width - tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + self.clear_cache() + time = [] + for k in range(num_frames): + self._conv_idx = [0] + tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width] + tile = self.post_quant_conv(tile) + decoded = self.decoder( + tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0) + ) + time.append(decoded) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if self.config.patch_size is not None: + dec = unpatchify(dec, patch_size=self.config.patch_size) + + dec = torch.clamp(dec, min=-1.0, max=1.0) + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: torch.Generator | None = None, + ) -> DecoderOutput | torch.Tensor: + """ + Args: + sample (`torch.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_rae.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_rae.py new file mode 100644 index 0000000000000000000000000000000000000000..58ea66f8d18d44652e8c0822ca54285c68ad5411 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_rae.py @@ -0,0 +1,689 @@ +# Copyright 2026 The NYU Vision-X and HuggingFace Teams. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from math import sqrt +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput, logging +from ...utils.accelerate_utils import apply_forward_hook +from ...utils.import_utils import is_transformers_available +from ...utils.torch_utils import randn_tensor + + +if is_transformers_available(): + from transformers import ( + Dinov2WithRegistersConfig, + Dinov2WithRegistersModel, + SiglipVisionConfig, + SiglipVisionModel, + ViTMAEConfig, + ViTMAEModel, + ) + +from ..activations import get_activation +from ..attention import AttentionMixin +from ..attention_processor import Attention +from ..embeddings import get_2d_sincos_pos_embed +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput + + +logger = logging.get_logger(__name__) + + +# --------------------------------------------------------------------------- +# Per-encoder forward functions +# --------------------------------------------------------------------------- +# Each function takes the raw transformers model + images and returns patch +# tokens of shape (B, N, C), stripping CLS / register tokens as needed. + + +def _dinov2_encoder_forward(model: nn.Module, images: torch.Tensor) -> torch.Tensor: + outputs = model(images, output_hidden_states=True) + unused_token_num = 5 # 1 CLS + 4 register tokens + return outputs.last_hidden_state[:, unused_token_num:] + + +def _siglip2_encoder_forward(model: nn.Module, images: torch.Tensor) -> torch.Tensor: + outputs = model(images, output_hidden_states=True, interpolate_pos_encoding=True) + return outputs.last_hidden_state + + +def _mae_encoder_forward(model: nn.Module, images: torch.Tensor, patch_size: int) -> torch.Tensor: + h, w = images.shape[2], images.shape[3] + patch_num = int(h * w // patch_size**2) + if patch_num * patch_size**2 != h * w: + raise ValueError("Image size should be divisible by patch size.") + noise = torch.arange(patch_num).unsqueeze(0).expand(images.shape[0], -1).to(images.device).to(images.dtype) + outputs = model(images, noise, interpolate_pos_encoding=True) + return outputs.last_hidden_state[:, 1:] # remove cls token + + +# --------------------------------------------------------------------------- +# Encoder construction helpers +# --------------------------------------------------------------------------- + + +def _build_encoder( + encoder_type: str, hidden_size: int, patch_size: int, num_hidden_layers: int, head_dim: int = 64 +) -> nn.Module: + """Build a frozen encoder from config (no pretrained download).""" + num_attention_heads = hidden_size // head_dim # all supported encoders use head_dim=64 + + if encoder_type == "dinov2": + config = Dinov2WithRegistersConfig( + hidden_size=hidden_size, + patch_size=patch_size, + image_size=518, + num_attention_heads=num_attention_heads, + num_hidden_layers=num_hidden_layers, + ) + model = Dinov2WithRegistersModel(config) + # RAE strips the final layernorm affine params (identity LN). Remove them from + # the architecture so `from_pretrained` doesn't leave them on the meta device. + model.layernorm.weight = None + model.layernorm.bias = None + elif encoder_type == "siglip2": + config = SiglipVisionConfig( + hidden_size=hidden_size, + patch_size=patch_size, + image_size=256, + num_attention_heads=num_attention_heads, + num_hidden_layers=num_hidden_layers, + ) + model = SiglipVisionModel(config) + # See dinov2 comment above. + model.vision_model.post_layernorm.weight = None + model.vision_model.post_layernorm.bias = None + elif encoder_type == "mae": + config = ViTMAEConfig( + hidden_size=hidden_size, + patch_size=patch_size, + image_size=224, + num_attention_heads=num_attention_heads, + num_hidden_layers=num_hidden_layers, + mask_ratio=0.0, + ) + model = ViTMAEModel(config) + # See dinov2 comment above. + model.layernorm.weight = None + model.layernorm.bias = None + else: + raise ValueError(f"Unknown encoder_type='{encoder_type}'. Available: dinov2, siglip2, mae") + + model.requires_grad_(False) + return model + + +_ENCODER_FORWARD_FNS = { + "dinov2": _dinov2_encoder_forward, + "siglip2": _siglip2_encoder_forward, + "mae": _mae_encoder_forward, +} + + +@dataclass +class RAEDecoderOutput(BaseOutput): + """ + Output of `RAEDecoder`. + + Args: + logits (`torch.Tensor`): + Patch reconstruction logits of shape `(batch_size, num_patches, patch_size**2 * num_channels)`. + """ + + logits: torch.Tensor + + +class ViTMAEIntermediate(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str = "gelu"): + super().__init__() + self.dense = nn.Linear(hidden_size, intermediate_size) + self.intermediate_act_fn = get_activation(hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class ViTMAEOutput(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, hidden_dropout_prob: float = 0.0): + super().__init__() + self.dense = nn.Linear(intermediate_size, hidden_size) + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states + input_tensor + return hidden_states + + +class ViTMAELayer(nn.Module): + """ + This matches the naming/parameter structure used in RAE-main (ViTMAE decoder block). + """ + + def __init__( + self, + *, + hidden_size: int, + num_attention_heads: int, + intermediate_size: int, + qkv_bias: bool = True, + layer_norm_eps: float = 1e-12, + hidden_dropout_prob: float = 0.0, + attention_probs_dropout_prob: float = 0.0, + hidden_act: str = "gelu", + ): + super().__init__() + if hidden_size % num_attention_heads != 0: + raise ValueError( + f"hidden_size={hidden_size} must be divisible by num_attention_heads={num_attention_heads}" + ) + self.attention = Attention( + query_dim=hidden_size, + heads=num_attention_heads, + dim_head=hidden_size // num_attention_heads, + dropout=attention_probs_dropout_prob, + bias=qkv_bias, + ) + self.intermediate = ViTMAEIntermediate( + hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_act=hidden_act + ) + self.output = ViTMAEOutput( + hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_dropout_prob=hidden_dropout_prob + ) + self.layernorm_before = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + self.layernorm_after = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + attention_output = self.attention(self.layernorm_before(hidden_states)) + hidden_states = attention_output + hidden_states + + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + layer_output = self.output(layer_output, hidden_states) + return layer_output + + +class RAEDecoder(nn.Module): + """ + Decoder implementation ported from RAE-main to keep checkpoint compatibility. + + Key attributes (must match checkpoint keys): + - decoder_embed + - decoder_pos_embed + - decoder_layers + - decoder_norm + - decoder_pred + - trainable_cls_token + """ + + def __init__( + self, + hidden_size: int = 768, + decoder_hidden_size: int = 512, + decoder_num_hidden_layers: int = 8, + decoder_num_attention_heads: int = 16, + decoder_intermediate_size: int = 2048, + num_patches: int = 256, + patch_size: int = 16, + num_channels: int = 3, + image_size: int = 256, + qkv_bias: bool = True, + layer_norm_eps: float = 1e-12, + hidden_dropout_prob: float = 0.0, + attention_probs_dropout_prob: float = 0.0, + hidden_act: str = "gelu", + ): + super().__init__() + self.decoder_hidden_size = decoder_hidden_size + self.patch_size = patch_size + self.num_channels = num_channels + self.image_size = image_size + self.num_patches = num_patches + + self.decoder_embed = nn.Linear(hidden_size, decoder_hidden_size, bias=True) + grid_size = int(num_patches**0.5) + pos_embed = get_2d_sincos_pos_embed( + decoder_hidden_size, grid_size, cls_token=True, extra_tokens=1, output_type="pt" + ) + self.register_buffer("decoder_pos_embed", pos_embed.unsqueeze(0).float(), persistent=False) + + self.decoder_layers = nn.ModuleList( + [ + ViTMAELayer( + hidden_size=decoder_hidden_size, + num_attention_heads=decoder_num_attention_heads, + intermediate_size=decoder_intermediate_size, + qkv_bias=qkv_bias, + layer_norm_eps=layer_norm_eps, + hidden_dropout_prob=hidden_dropout_prob, + attention_probs_dropout_prob=attention_probs_dropout_prob, + hidden_act=hidden_act, + ) + for _ in range(decoder_num_hidden_layers) + ] + ) + + self.decoder_norm = nn.LayerNorm(decoder_hidden_size, eps=layer_norm_eps) + self.decoder_pred = nn.Linear(decoder_hidden_size, patch_size**2 * num_channels, bias=True) + self.gradient_checkpointing = False + + self.trainable_cls_token = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size)) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor: + embeddings_positions = embeddings.shape[1] - 1 + num_positions = self.decoder_pos_embed.shape[1] - 1 + + class_pos_embed = self.decoder_pos_embed[:, 0, :] + patch_pos_embed = self.decoder_pos_embed[:, 1:, :] + dim = self.decoder_pos_embed.shape[-1] + + patch_pos_embed = patch_pos_embed.reshape(1, 1, -1, dim).permute(0, 3, 1, 2) + patch_pos_embed = F.interpolate( + patch_pos_embed, + scale_factor=(1, embeddings_positions / num_positions), + mode="bicubic", + align_corners=False, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def interpolate_latent(self, x: torch.Tensor) -> torch.Tensor: + b, l, c = x.shape + if l == self.num_patches: + return x + h = w = int(l**0.5) + x = x.reshape(b, h, w, c).permute(0, 3, 1, 2) + target_size = (int(self.num_patches**0.5), int(self.num_patches**0.5)) + x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False) + x = x.permute(0, 2, 3, 1).contiguous().view(b, self.num_patches, c) + return x + + def unpatchify(self, patchified_pixel_values: torch.Tensor, original_image_size: tuple[int, int] | None = None): + patch_size, num_channels = self.patch_size, self.num_channels + original_image_size = ( + original_image_size if original_image_size is not None else (self.image_size, self.image_size) + ) + original_height, original_width = original_image_size + num_patches_h = original_height // patch_size + num_patches_w = original_width // patch_size + if num_patches_h * num_patches_w != patchified_pixel_values.shape[1]: + raise ValueError( + f"The number of patches in the patchified pixel values {patchified_pixel_values.shape[1]}, does not match the number of patches on original image {num_patches_h}*{num_patches_w}" + ) + + batch_size = patchified_pixel_values.shape[0] + patchified_pixel_values = patchified_pixel_values.reshape( + batch_size, + num_patches_h, + num_patches_w, + patch_size, + patch_size, + num_channels, + ) + patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values) + pixel_values = patchified_pixel_values.reshape( + batch_size, + num_channels, + num_patches_h * patch_size, + num_patches_w * patch_size, + ) + return pixel_values + + def forward( + self, + hidden_states: torch.Tensor, + *, + interpolate_pos_encoding: bool = False, + drop_cls_token: bool = False, + return_dict: bool = True, + ) -> RAEDecoderOutput | tuple[torch.Tensor]: + x = self.decoder_embed(hidden_states) + if drop_cls_token: + x_ = x[:, 1:, :] + x_ = self.interpolate_latent(x_) + else: + x_ = self.interpolate_latent(x) + + cls_token = self.trainable_cls_token.expand(x_.shape[0], -1, -1) + x = torch.cat([cls_token, x_], dim=1) + + if interpolate_pos_encoding: + if not drop_cls_token: + raise ValueError("interpolate_pos_encoding only supports drop_cls_token=True") + decoder_pos_embed = self.interpolate_pos_encoding(x) + else: + decoder_pos_embed = self.decoder_pos_embed + + hidden_states = x + decoder_pos_embed.to(device=x.device, dtype=x.dtype) + + for layer_module in self.decoder_layers: + hidden_states = layer_module(hidden_states) + + hidden_states = self.decoder_norm(hidden_states) + logits = self.decoder_pred(hidden_states) + logits = logits[:, 1:, :] + + if not return_dict: + return (logits,) + return RAEDecoderOutput(logits=logits) + + +class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin): + r""" + Representation Autoencoder (RAE) model for encoding images to latents and decoding latents to images. + + This model uses a frozen pretrained encoder (DINOv2, SigLIP2, or MAE) with a trainable ViT decoder to reconstruct + images from learned representations. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for + all models (such as downloading or saving). + + Args: + encoder_type (`str`, *optional*, defaults to `"dinov2"`): + Type of frozen encoder to use. One of `"dinov2"`, `"siglip2"`, or `"mae"`. + encoder_hidden_size (`int`, *optional*, defaults to `768`): + Hidden size of the encoder model. + encoder_patch_size (`int`, *optional*, defaults to `14`): + Patch size of the encoder model. + encoder_num_hidden_layers (`int`, *optional*, defaults to `12`): + Number of hidden layers in the encoder model. + patch_size (`int`, *optional*, defaults to `16`): + Decoder patch size (used for unpatchify and decoder head). + encoder_input_size (`int`, *optional*, defaults to `224`): + Input size expected by the encoder. + image_size (`int`, *optional*): + Decoder output image size. If `None`, it is derived from encoder token count and `patch_size` like + RAE-main: `image_size = patch_size * sqrt(num_patches)`, where `num_patches = (encoder_input_size // + encoder_patch_size) ** 2`. + num_channels (`int`, *optional*, defaults to `3`): + Number of input/output channels. + encoder_norm_mean (`list`, *optional*, defaults to `[0.485, 0.456, 0.406]`): + Channel-wise mean for encoder input normalization (ImageNet defaults). + encoder_norm_std (`list`, *optional*, defaults to `[0.229, 0.224, 0.225]`): + Channel-wise std for encoder input normalization (ImageNet defaults). + latents_mean (`list` or `tuple`, *optional*): + Optional mean for latent normalization. Tensor inputs are accepted and converted to config-serializable + lists. + latents_std (`list` or `tuple`, *optional*): + Optional standard deviation for latent normalization. Tensor inputs are accepted and converted to + config-serializable lists. + noise_tau (`float`, *optional*, defaults to `0.0`): + Noise level for training (adds noise to latents during training). + reshape_to_2d (`bool`, *optional*, defaults to `True`): + Whether to reshape latents to 2D (B, C, H, W) format. + use_encoder_loss (`bool`, *optional*, defaults to `False`): + Whether to use encoder hidden states in the loss (for advanced training). + """ + + # NOTE: gradient checkpointing is not wired up for this model yet. + _supports_gradient_checkpointing = False + _no_split_modules = ["ViTMAELayer"] + _keys_to_ignore_on_load_unexpected = ["decoder.decoder_pos_embed"] + + @register_to_config + def __init__( + self, + encoder_type: str = "dinov2", + encoder_hidden_size: int = 768, + encoder_patch_size: int = 14, + encoder_num_hidden_layers: int = 12, + decoder_hidden_size: int = 512, + decoder_num_hidden_layers: int = 8, + decoder_num_attention_heads: int = 16, + decoder_intermediate_size: int = 2048, + patch_size: int = 16, + encoder_input_size: int = 224, + image_size: int | None = None, + num_channels: int = 3, + encoder_norm_mean: list | None = None, + encoder_norm_std: list | None = None, + latents_mean: list | tuple | torch.Tensor | None = None, + latents_std: list | tuple | torch.Tensor | None = None, + noise_tau: float = 0.0, + reshape_to_2d: bool = True, + use_encoder_loss: bool = False, + scaling_factor: float = 1.0, + ): + super().__init__() + + if encoder_type not in _ENCODER_FORWARD_FNS: + raise ValueError( + f"Unknown encoder_type='{encoder_type}'. Available: {sorted(_ENCODER_FORWARD_FNS.keys())}" + ) + + def _to_config_compatible(value: Any) -> Any: + if isinstance(value, torch.Tensor): + return value.detach().cpu().tolist() + if isinstance(value, tuple): + return [_to_config_compatible(v) for v in value] + if isinstance(value, list): + return [_to_config_compatible(v) for v in value] + return value + + def _as_optional_tensor(value: torch.Tensor | list | tuple | None) -> torch.Tensor | None: + if value is None: + return None + if isinstance(value, torch.Tensor): + return value.detach().clone() + return torch.tensor(value, dtype=torch.float32) + + latents_std_tensor = _as_optional_tensor(latents_std) + + # Ensure config values are JSON-serializable (list/None), even if caller passes torch.Tensors. + self.register_to_config( + latents_mean=_to_config_compatible(latents_mean), + latents_std=_to_config_compatible(latents_std), + ) + + self.encoder_input_size = encoder_input_size + self.noise_tau = float(noise_tau) + self.reshape_to_2d = bool(reshape_to_2d) + self.use_encoder_loss = bool(use_encoder_loss) + + # Validate early, before building the (potentially large) encoder/decoder. + encoder_patch_size = int(encoder_patch_size) + if self.encoder_input_size % encoder_patch_size != 0: + raise ValueError( + f"encoder_input_size={self.encoder_input_size} must be divisible by encoder_patch_size={encoder_patch_size}." + ) + decoder_patch_size = int(patch_size) + if decoder_patch_size <= 0: + raise ValueError("patch_size must be a positive integer (this is decoder_patch_size).") + + # Frozen representation encoder (built from config, no downloads) + self.encoder: nn.Module = _build_encoder( + encoder_type=encoder_type, + hidden_size=encoder_hidden_size, + patch_size=encoder_patch_size, + num_hidden_layers=encoder_num_hidden_layers, + ) + self._encoder_forward_fn = _ENCODER_FORWARD_FNS[encoder_type] + num_patches = (self.encoder_input_size // encoder_patch_size) ** 2 + + grid = int(sqrt(num_patches)) + if grid * grid != num_patches: + raise ValueError(f"Computed num_patches={num_patches} must be a perfect square.") + + derived_image_size = decoder_patch_size * grid + if image_size is None: + image_size = derived_image_size + else: + image_size = int(image_size) + if image_size != derived_image_size: + raise ValueError( + f"image_size={image_size} must equal decoder_patch_size*sqrt(num_patches)={derived_image_size} " + f"for patch_size={decoder_patch_size} and computed num_patches={num_patches}." + ) + + # Encoder input normalization stats (ImageNet defaults) + if encoder_norm_mean is None: + encoder_norm_mean = [0.485, 0.456, 0.406] + if encoder_norm_std is None: + encoder_norm_std = [0.229, 0.224, 0.225] + encoder_mean_tensor = torch.tensor(encoder_norm_mean, dtype=torch.float32).view(1, 3, 1, 1) + encoder_std_tensor = torch.tensor(encoder_norm_std, dtype=torch.float32).view(1, 3, 1, 1) + + self.register_buffer("encoder_mean", encoder_mean_tensor, persistent=True) + self.register_buffer("encoder_std", encoder_std_tensor, persistent=True) + + # Latent normalization buffers (defaults are no-ops; actual values come from checkpoint) + latents_mean_tensor = _as_optional_tensor(latents_mean) + if latents_mean_tensor is None: + latents_mean_tensor = torch.zeros(1) + self.register_buffer("_latents_mean", latents_mean_tensor, persistent=True) + + if latents_std_tensor is None: + latents_std_tensor = torch.ones(1) + self.register_buffer("_latents_std", latents_std_tensor, persistent=True) + + # ViT-MAE style decoder + self.decoder = RAEDecoder( + hidden_size=int(encoder_hidden_size), + decoder_hidden_size=int(decoder_hidden_size), + decoder_num_hidden_layers=int(decoder_num_hidden_layers), + decoder_num_attention_heads=int(decoder_num_attention_heads), + decoder_intermediate_size=int(decoder_intermediate_size), + num_patches=int(num_patches), + patch_size=int(decoder_patch_size), + num_channels=int(num_channels), + image_size=int(image_size), + ) + self.num_patches = int(num_patches) + self.decoder_patch_size = int(decoder_patch_size) + self.decoder_image_size = int(image_size) + + # Slicing support (batch dimension) similar to other diffusers autoencoders + self.use_slicing = False + + def _noising(self, x: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor: + # Per-sample random sigma in [0, noise_tau] + noise_sigma = self.noise_tau * torch.rand( + (x.size(0),) + (1,) * (x.ndim - 1), device=x.device, dtype=x.dtype, generator=generator + ) + return x + noise_sigma * randn_tensor(x.shape, generator=generator, device=x.device, dtype=x.dtype) + + def _resize_and_normalize(self, x: torch.Tensor) -> torch.Tensor: + _, _, h, w = x.shape + if h != self.encoder_input_size or w != self.encoder_input_size: + x = F.interpolate( + x, size=(self.encoder_input_size, self.encoder_input_size), mode="bicubic", align_corners=False + ) + mean = self.encoder_mean.to(device=x.device, dtype=x.dtype) + std = self.encoder_std.to(device=x.device, dtype=x.dtype) + return (x - mean) / std + + def _denormalize_image(self, x: torch.Tensor) -> torch.Tensor: + mean = self.encoder_mean.to(device=x.device, dtype=x.dtype) + std = self.encoder_std.to(device=x.device, dtype=x.dtype) + return x * std + mean + + def _normalize_latents(self, z: torch.Tensor) -> torch.Tensor: + latents_mean = self._latents_mean.to(device=z.device, dtype=z.dtype) + latents_std = self._latents_std.to(device=z.device, dtype=z.dtype) + return (z - latents_mean) / (latents_std + 1e-5) + + def _denormalize_latents(self, z: torch.Tensor) -> torch.Tensor: + latents_mean = self._latents_mean.to(device=z.device, dtype=z.dtype) + latents_std = self._latents_std.to(device=z.device, dtype=z.dtype) + return z * (latents_std + 1e-5) + latents_mean + + def _encode(self, x: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor: + x = self._resize_and_normalize(x) + + if self.config.encoder_type == "mae": + tokens = self._encoder_forward_fn(self.encoder, x, self.config.encoder_patch_size) + else: + tokens = self._encoder_forward_fn(self.encoder, x) # (B, N, C) + + if self.training and self.noise_tau > 0: + tokens = self._noising(tokens, generator=generator) + + if self.reshape_to_2d: + b, n, c = tokens.shape + side = int(sqrt(n)) + if side * side != n: + raise ValueError(f"Token length n={n} is not a perfect square; cannot reshape to 2D.") + z = tokens.transpose(1, 2).contiguous().view(b, c, side, side) # (B, C, h, w) + else: + z = tokens + + z = self._normalize_latents(z) + + # Follow diffusers convention: optionally scale latents for diffusion + if self.config.scaling_factor != 1.0: + z = z * self.config.scaling_factor + + return z + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True, generator: torch.Generator | None = None + ) -> EncoderOutput | tuple[torch.Tensor]: + if self.use_slicing and x.shape[0] > 1: + latents = torch.cat([self._encode(x_slice, generator=generator) for x_slice in x.split(1)], dim=0) + else: + latents = self._encode(x, generator=generator) + + if not return_dict: + return (latents,) + return EncoderOutput(latent=latents) + + def _decode(self, z: torch.Tensor) -> torch.Tensor: + # Undo scaling factor if applied at encode time + if self.config.scaling_factor != 1.0: + z = z / self.config.scaling_factor + + z = self._denormalize_latents(z) + + if self.reshape_to_2d: + b, c, h, w = z.shape + tokens = z.view(b, c, h * w).transpose(1, 2).contiguous() # (B, N, C) + else: + tokens = z + + logits = self.decoder(tokens, return_dict=True).logits + x_rec = self.decoder.unpatchify(logits) + x_rec = self._denormalize_image(x_rec) + return x_rec.to(device=z.device) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor]: + if self.use_slicing and z.shape[0] > 1: + decoded = torch.cat([self._decode(z_slice) for z_slice in z.split(1)], dim=0) + else: + decoded = self._decode(z) + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def forward( + self, sample: torch.Tensor, return_dict: bool = True, generator: torch.Generator | None = None + ) -> DecoderOutput | tuple[torch.Tensor]: + latents = self.encode(sample, return_dict=False, generator=generator)[0] + decoded = self.decode(latents, return_dict=False)[0] + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_tiny.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_tiny.py new file mode 100644 index 0000000000000000000000000000000000000000..14bc57ca90223a8ab574e9c1d2afd0d1019c36c4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_tiny.py @@ -0,0 +1,316 @@ +# Copyright 2025 Ollin Boer Bohan and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass + +import torch + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ...utils.accelerate_utils import apply_forward_hook +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin, DecoderOutput, DecoderTiny, EncoderTiny + + +@dataclass +class AutoencoderTinyOutput(BaseOutput): + """ + Output of AutoencoderTiny encoding method. + + Args: + latents (`torch.Tensor`): Encoded outputs of the `Encoder`. + + """ + + latents: torch.Tensor + + +class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin): + r""" + A tiny distilled VAE model for encoding images into latents and decoding latent representations into images. + + [`AutoencoderTiny`] is a wrapper around the original implementation of `TAESD`. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for + all models (such as downloading or saving). + + Parameters: + in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image. + out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. + encoder_block_out_channels (`tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`): + tuple of integers representing the number of output channels for each encoder block. The length of the + tuple should be equal to the number of encoder blocks. + decoder_block_out_channels (`tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`): + tuple of integers representing the number of output channels for each decoder block. The length of the + tuple should be equal to the number of decoder blocks. + act_fn (`str`, *optional*, defaults to `"relu"`): + Activation function to be used throughout the model. + latent_channels (`int`, *optional*, defaults to 4): + Number of channels in the latent representation. The latent space acts as a compressed representation of + the input image. + upsampling_scaling_factor (`int`, *optional*, defaults to 2): + Scaling factor for upsampling in the decoder. It determines the size of the output image during the + upsampling process. + num_encoder_blocks (`tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`): + tuple of integers representing the number of encoder blocks at each stage of the encoding process. The + length of the tuple should be equal to the number of stages in the encoder. Each stage has a different + number of encoder blocks. + num_decoder_blocks (`tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`): + tuple of integers representing the number of decoder blocks at each stage of the decoding process. The + length of the tuple should be equal to the number of stages in the decoder. Each stage has a different + number of decoder blocks. + latent_magnitude (`float`, *optional*, defaults to 3.0): + Magnitude of the latent representation. This parameter scales the latent representation values to control + the extent of information preservation. + latent_shift (float, *optional*, defaults to 0.5): + Shift applied to the latent representation. This parameter controls the center of the latent space. + scaling_factor (`float`, *optional*, defaults to 1.0): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper. For this + Autoencoder, however, no such scaling factor was used, hence the value of 1.0 as the default. + force_upcast (`bool`, *optional*, default to `False`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without losing too much precision, in which case + `force_upcast` can be set to `False` (see this fp16-friendly + [AutoEncoder](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)). + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + encoder_block_out_channels: tuple[int, ...] = (64, 64, 64, 64), + decoder_block_out_channels: tuple[int, ...] = (64, 64, 64, 64), + act_fn: str = "relu", + upsample_fn: str = "nearest", + latent_channels: int = 4, + upsampling_scaling_factor: int = 2, + num_encoder_blocks: tuple[int, ...] = (1, 3, 3, 3), + num_decoder_blocks: tuple[int, ...] = (3, 3, 3, 1), + latent_magnitude: int = 3, + latent_shift: float = 0.5, + force_upcast: bool = False, + scaling_factor: float = 1.0, + shift_factor: float = 0.0, + ): + super().__init__() + + if len(encoder_block_out_channels) != len(num_encoder_blocks): + raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.") + if len(decoder_block_out_channels) != len(num_decoder_blocks): + raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.") + + self.encoder = EncoderTiny( + in_channels=in_channels, + out_channels=latent_channels, + num_blocks=num_encoder_blocks, + block_out_channels=encoder_block_out_channels, + act_fn=act_fn, + ) + + self.decoder = DecoderTiny( + in_channels=latent_channels, + out_channels=out_channels, + num_blocks=num_decoder_blocks, + block_out_channels=decoder_block_out_channels, + upsampling_scaling_factor=upsampling_scaling_factor, + act_fn=act_fn, + upsample_fn=upsample_fn, + ) + + self.latent_magnitude = latent_magnitude + self.latent_shift = latent_shift + self.scaling_factor = scaling_factor + + self.use_slicing = False + self.use_tiling = False + + # only relevant if vae tiling is enabled + self.spatial_scale_factor = 2**out_channels + self.tile_overlap_factor = 0.125 + self.tile_sample_min_size = 512 + self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor + + self.register_to_config(block_out_channels=decoder_block_out_channels) + self.register_to_config(force_upcast=False) + + def scale_latents(self, x: torch.Tensor) -> torch.Tensor: + """raw latents -> [0, 1]""" + return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1) + + def unscale_latents(self, x: torch.Tensor) -> torch.Tensor: + """[0, 1] -> raw latents""" + return x.sub(self.latent_shift).mul(2 * self.latent_magnitude) + + def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. + + Args: + x (`torch.Tensor`): Input batch of images. + + Returns: + `torch.Tensor`: Encoded batch of images. + """ + # scale of encoder output relative to input + sf = self.spatial_scale_factor + tile_size = self.tile_sample_min_size + + # number of pixels to blend and to traverse between tile + blend_size = int(tile_size * self.tile_overlap_factor) + traverse_size = tile_size - blend_size + + # tiles index (up/left) + ti = range(0, x.shape[-2], traverse_size) + tj = range(0, x.shape[-1], traverse_size) + + # mask for blending + blend_masks = torch.stack( + torch.meshgrid([torch.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij") + ) + blend_masks = blend_masks.clamp(0, 1).to(x.device) + + # output array + out = torch.zeros(x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf, device=x.device) + for i in ti: + for j in tj: + tile_in = x[..., i : i + tile_size, j : j + tile_size] + # tile result + tile_out = out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf] + tile = self.encoder(tile_in) + h, w = tile.shape[-2], tile.shape[-1] + # blend tile result into output + blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0] + blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1] + blend_mask = blend_mask_i * blend_mask_j + tile, blend_mask = tile[..., :h, :w], blend_mask[..., :h, :w] + tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out) + return out + + def _tiled_decode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. + + Args: + x (`torch.Tensor`): Input batch of images. + + Returns: + `torch.Tensor`: Encoded batch of images. + """ + # scale of decoder output relative to input + sf = self.spatial_scale_factor + tile_size = self.tile_latent_min_size + + # number of pixels to blend and to traverse between tiles + blend_size = int(tile_size * self.tile_overlap_factor) + traverse_size = tile_size - blend_size + + # tiles index (up/left) + ti = range(0, x.shape[-2], traverse_size) + tj = range(0, x.shape[-1], traverse_size) + + # mask for blending + blend_masks = torch.stack( + torch.meshgrid([torch.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij") + ) + blend_masks = blend_masks.clamp(0, 1).to(x.device) + + # output array + out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device) + for i in ti: + for j in tj: + tile_in = x[..., i : i + tile_size, j : j + tile_size] + # tile result + tile_out = out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf] + tile = self.decoder(tile_in) + h, w = tile.shape[-2], tile.shape[-1] + # blend tile result into output + blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0] + blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1] + blend_mask = (blend_mask_i * blend_mask_j)[..., :h, :w] + tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out) + return out + + @apply_forward_hook + def encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderTinyOutput | tuple[torch.Tensor]: + if self.use_slicing and x.shape[0] > 1: + output = [ + self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1) + ] + output = torch.cat(output) + else: + output = self._tiled_encode(x) if self.use_tiling else self.encoder(x) + + if not return_dict: + return (output,) + + return AutoencoderTinyOutput(latents=output) + + @apply_forward_hook + def decode( + self, x: torch.Tensor, generator: torch.Generator | None = None, return_dict: bool = True + ) -> DecoderOutput | tuple[torch.Tensor]: + if self.use_slicing and x.shape[0] > 1: + output = [ + self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x_slice) for x_slice in x.split(1) + ] + output = torch.cat(output) + else: + output = self._tiled_decode(x) if self.use_tiling else self.decoder(x) + + if not return_dict: + return (output,) + + return DecoderOutput(sample=output) + + def forward( + self, + sample: torch.Tensor, + return_dict: bool = True, + ) -> DecoderOutput | tuple[torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + enc = self.encode(sample).latents + + # scale latents to be in [0, 1], then quantize latents to a byte tensor, + # as if we were storing the latents in an RGBA uint8 image. + scaled_enc = self.scale_latents(enc).mul_(255).round_().byte() + + # unquantize latents back into [0, 1], then unscale latents back to their original range, + # as if we were loading the latents from an RGBA uint8 image. + unscaled_enc = self.unscale_latents(scaled_enc / 255.0) + + dec = self.decode(unscaled_enc).sample + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/consistency_decoder_vae.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/consistency_decoder_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..2d53b745e0fbc80275c3cef53b635cf726820da5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/consistency_decoder_vae.py @@ -0,0 +1,368 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass + +import torch +import torch.nn.functional as F +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...schedulers import ConsistencyDecoderScheduler +from ...utils import BaseOutput +from ...utils.accelerate_utils import apply_forward_hook +from ...utils.torch_utils import randn_tensor +from ..attention import AttentionMixin +from ..attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttnAddedKVProcessor, + AttnProcessor, +) +from ..modeling_utils import ModelMixin +from ..unets.unet_2d import UNet2DModel +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder + + +@dataclass +class ConsistencyDecoderVAEOutput(BaseOutput): + """ + Output of encoding method. + + Args: + latent_dist (`DiagonalGaussianDistribution`): + Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`. + `DiagonalGaussianDistribution` allows for sampling latents from the distribution. + """ + + latent_dist: "DiagonalGaussianDistribution" + + +class ConsistencyDecoderVAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin): + r""" + The consistency decoder used with DALL-E 3. + + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionPipeline, ConsistencyDecoderVAE + + >>> vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16) + >>> pipe = StableDiffusionPipeline.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16 + ... ).to("cuda") + + >>> image = pipe("horse", generator=torch.manual_seed(0)).images[0] + >>> image + ``` + """ + + _supports_group_offloading = False + + @register_to_config + def __init__( + self, + scaling_factor: float = 0.18215, + latent_channels: int = 4, + sample_size: int = 32, + encoder_act_fn: str = "silu", + encoder_block_out_channels: tuple[int, ...] = (128, 256, 512, 512), + encoder_double_z: bool = True, + encoder_down_block_types: tuple[str, ...] = ( + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ), + encoder_in_channels: int = 3, + encoder_layers_per_block: int = 2, + encoder_norm_num_groups: int = 32, + encoder_out_channels: int = 4, + decoder_add_attention: bool = False, + decoder_block_out_channels: tuple[int, ...] = (320, 640, 1024, 1024), + decoder_down_block_types: tuple[str, ...] = ( + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + ), + decoder_downsample_padding: int = 1, + decoder_in_channels: int = 7, + decoder_layers_per_block: int = 3, + decoder_norm_eps: float = 1e-05, + decoder_norm_num_groups: int = 32, + decoder_num_train_timesteps: int = 1024, + decoder_out_channels: int = 6, + decoder_resnet_time_scale_shift: str = "scale_shift", + decoder_time_embedding_type: str = "learned", + decoder_up_block_types: tuple[str, ...] = ( + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + ), + ): + super().__init__() + self.encoder = Encoder( + act_fn=encoder_act_fn, + block_out_channels=encoder_block_out_channels, + double_z=encoder_double_z, + down_block_types=encoder_down_block_types, + in_channels=encoder_in_channels, + layers_per_block=encoder_layers_per_block, + norm_num_groups=encoder_norm_num_groups, + out_channels=encoder_out_channels, + ) + + self.decoder_unet = UNet2DModel( + add_attention=decoder_add_attention, + block_out_channels=decoder_block_out_channels, + down_block_types=decoder_down_block_types, + downsample_padding=decoder_downsample_padding, + in_channels=decoder_in_channels, + layers_per_block=decoder_layers_per_block, + norm_eps=decoder_norm_eps, + norm_num_groups=decoder_norm_num_groups, + num_train_timesteps=decoder_num_train_timesteps, + out_channels=decoder_out_channels, + resnet_time_scale_shift=decoder_resnet_time_scale_shift, + time_embedding_type=decoder_time_embedding_type, + up_block_types=decoder_up_block_types, + ) + self.decoder_scheduler = ConsistencyDecoderScheduler() + self.register_to_config(block_out_channels=encoder_block_out_channels) + self.register_to_config(force_upcast=False) + self.register_buffer( + "means", + torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None], + persistent=False, + ) + self.register_buffer( + "stds", torch.tensor([0.9654121, 1.0440036, 0.76147926, 0.77022034])[None, :, None, None], persistent=False + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + + self.use_slicing = False + self.use_tiling = False + + # only relevant if vae tiling is enabled + self.tile_sample_min_size = self.config.sample_size + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor = 0.25 + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> ConsistencyDecoderVAEOutput | tuple[DiagonalGaussianDistribution]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] + instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned, otherwise a + plain `tuple` is returned. + """ + if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): + return self.tiled_encode(x, return_dict=return_dict) + + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self.encoder(x) + + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return ConsistencyDecoderVAEOutput(latent_dist=posterior) + + @apply_forward_hook + def decode( + self, + z: torch.Tensor, + generator: torch.Generator | None = None, + return_dict: bool = True, + num_inference_steps: int = 2, + ) -> DecoderOutput | tuple[torch.Tensor]: + """ + Decodes the input latent vector `z` using the consistency decoder VAE model. + + Args: + z (torch.Tensor): The input latent vector. + generator (torch.Generator | None): The random number generator. Default is None. + return_dict (bool): Whether to return the output as a dictionary. Default is True. + num_inference_steps (int): The number of inference steps. Default is 2. + + Returns: + DecoderOutput | tuple[torch.Tensor]: The decoded output. + + """ + z = (z * self.config.scaling_factor - self.means) / self.stds + + scale_factor = 2 ** (len(self.config.block_out_channels) - 1) + z = F.interpolate(z, mode="nearest", scale_factor=scale_factor) + + batch_size, _, height, width = z.shape + + self.decoder_scheduler.set_timesteps(num_inference_steps, device=self.device) + + x_t = self.decoder_scheduler.init_noise_sigma * randn_tensor( + (batch_size, 3, height, width), generator=generator, dtype=z.dtype, device=z.device + ) + + for t in self.decoder_scheduler.timesteps: + model_input = torch.concat([self.decoder_scheduler.scale_model_input(x_t, t), z], dim=1) + model_output = self.decoder_unet(model_input, t).sample[:, :3, :, :] + prev_sample = self.decoder_scheduler.step(model_output, t, x_t, generator).prev_sample + x_t = prev_sample + + x_0 = x_t + + if not return_dict: + return (x_0,) + + return DecoderOutput(sample=x_0) + + # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_v + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) + return b + + # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_h + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) + return b + + def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> ConsistencyDecoderVAEOutput | tuple: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] + instead of a plain tuple. + + Returns: + [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] + is returned, otherwise a plain `tuple` is returned. + """ + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + moments = torch.cat(result_rows, dim=2) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return ConsistencyDecoderVAEOutput(latent_dist=posterior) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: torch.Generator | None = None, + ) -> DecoderOutput | tuple[torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + generator (`torch.Generator`, *optional*, defaults to `None`): + Generator to use for sampling. + + Returns: + [`DecoderOutput`] or `tuple`: + If return_dict is True, a [`DecoderOutput`] is returned, otherwise a plain `tuple` is returned. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, generator=generator).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/vae.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..042cb5c10021d631c60e99fae0d3b33add57f5a5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/vae.py @@ -0,0 +1,927 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass + +import numpy as np +import torch +import torch.nn as nn + +from ...utils import BaseOutput +from ...utils.torch_utils import randn_tensor +from ..activations import get_activation +from ..attention_processor import SpatialNorm +from ..unets.unet_2d_blocks import ( + AutoencoderTinyBlock, + UNetMidBlock2D, + get_down_block, + get_up_block, +) + + +@dataclass +class EncoderOutput(BaseOutput): + r""" + Output of encoding method. + + Args: + latent (`torch.Tensor` of shape `(batch_size, num_channels, latent_height, latent_width)`): + The encoded latent. + """ + + latent: torch.Tensor + + +@dataclass +class DecoderOutput(BaseOutput): + r""" + Output of decoding method. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): + The decoded output sample from the last layer of the model. + """ + + sample: torch.Tensor + commit_loss: torch.FloatTensor | None = None + + +class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + down_block_types (`tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available + options. + block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + double_z (`bool`, *optional*, defaults to `True`): + Whether to double the number of output channels for the last block. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: tuple[str, ...] = ("DownEncoderBlock2D",), + block_out_channels: tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + mid_block_add_attention=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + ) + + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-6, + downsample_padding=0, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=None, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + add_attention=mid_block_add_attention, + ) + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `Encoder` class.""" + + sample = self.conv_in(sample) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + # down + for down_block in self.down_blocks: + sample = self._gradient_checkpointing_func(down_block, sample) + # middle + sample = self._gradient_checkpointing_func(self.mid_block, sample) + + else: + # down + for down_block in self.down_blocks: + sample = down_block(sample) + + # middle + sample = self.mid_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + norm_type (`str`, *optional*, defaults to `"group"`): + The normalization type to use. Can be either `"group"` or `"spatial"`. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",), + block_out_channels: tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + mid_block_add_attention=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + add_attention=mid_block_add_attention, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=temb_channels, + resnet_time_scale_shift=norm_type, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward( + self, + sample: torch.Tensor, + latent_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + r"""The forward method of the `Decoder` class.""" + + sample = self.conv_in(sample) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + # middle + sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds) + + # up + for up_block in self.up_blocks: + sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds) + else: + # middle + sample = self.mid_block(sample, latent_embeds) + + # up + for up_block in self.up_blocks: + sample = up_block(sample, latent_embeds) + + # post-process + if latent_embeds is None: + sample = self.conv_norm_out(sample) + else: + sample = self.conv_norm_out(sample, latent_embeds) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class UpSample(nn.Module): + r""" + The `UpSample` layer of a variational autoencoder that upsamples its input. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `UpSample` class.""" + x = torch.relu(x) + x = self.deconv(x) + return x + + +class MaskConditionEncoder(nn.Module): + """ + used in AsymmetricAutoencoderKL + """ + + def __init__( + self, + in_ch: int, + out_ch: int = 192, + res_ch: int = 768, + stride: int = 16, + ) -> None: + super().__init__() + + channels = [] + while stride > 1: + stride = stride // 2 + in_ch_ = out_ch * 2 + if out_ch > res_ch: + out_ch = res_ch + if stride == 1: + in_ch_ = res_ch + channels.append((in_ch_, out_ch)) + out_ch *= 2 + + out_channels = [] + for _in_ch, _out_ch in channels: + out_channels.append(_out_ch) + out_channels.append(channels[-1][0]) + + layers = [] + in_ch_ = in_ch + for l in range(len(out_channels)): + out_ch_ = out_channels[l] + if l == 0 or l == 1: + layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1)) + else: + layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1)) + in_ch_ = out_ch_ + + self.layers = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor: + r"""The forward method of the `MaskConditionEncoder` class.""" + out = {} + for l in range(len(self.layers)): + layer = self.layers[l] + x = layer(x) + out[str(tuple(x.shape))] = x + x = torch.relu(x) + return out + + +class MaskConditionDecoder(nn.Module): + r"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's + decoder with a conditioner on the mask and masked image. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + norm_type (`str`, *optional*, defaults to `"group"`): + The normalization type to use. Can be either `"group"` or `"spatial"`. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",), + block_out_channels: tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=None, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=temb_channels, + resnet_time_scale_shift=norm_type, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # condition encoder + self.condition_encoder = MaskConditionEncoder( + in_ch=out_channels, + out_ch=block_out_channels[0], + res_ch=block_out_channels[-1], + ) + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward( + self, + z: torch.Tensor, + image: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + latent_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + r"""The forward method of the `MaskConditionDecoder` class.""" + sample = z + sample = self.conv_in(sample) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + if torch.is_grad_enabled() and self.gradient_checkpointing: + # middle + sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds) + sample = sample.to(upscale_dtype) + + # condition encoder + if image is not None and mask is not None: + masked_image = (1 - mask) * image + im_x = self._gradient_checkpointing_func( + self.condition_encoder, + masked_image, + mask, + ) + + # up + for up_block in self.up_blocks: + if image is not None and mask is not None: + sample_ = im_x[str(tuple(sample.shape))] + mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest") + sample = sample * mask_ + sample_ * (1 - mask_) + sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds) + if image is not None and mask is not None: + sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask) + else: + # middle + sample = self.mid_block(sample, latent_embeds) + sample = sample.to(upscale_dtype) + + # condition encoder + if image is not None and mask is not None: + masked_image = (1 - mask) * image + im_x = self.condition_encoder(masked_image, mask) + + # up + for up_block in self.up_blocks: + if image is not None and mask is not None: + sample_ = im_x[str(tuple(sample.shape))] + mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest") + sample = sample * mask_ + sample_ * (1 - mask_) + sample = up_block(sample, latent_embeds) + if image is not None and mask is not None: + sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask) + + # post-process + if latent_embeds is None: + sample = self.conv_norm_out(sample) + else: + sample = self.conv_norm_out(sample, latent_embeds) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class VectorQuantizer(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix + multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__( + self, + n_e: int, + vq_embed_dim: int, + beta: float, + remap=None, + unknown_index: str = "random", + sane_index_shape: bool = False, + legacy: bool = True, + ): + super().__init__() + self.n_e = n_e + self.vq_embed_dim = vq_embed_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.used: torch.Tensor + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print( + f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." + ) + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor: + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor: + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, tuple]: + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.vq_embed_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1) + + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q: torch.Tensor = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices: torch.LongTensor, shape: tuple[int, ...]) -> torch.Tensor: + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q: torch.Tensor = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters: torch.Tensor, deterministic: bool = False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean, device=self.parameters.device, dtype=self.parameters.dtype + ) + + def sample(self, generator: torch.Generator | None = None) -> torch.Tensor: + # make sure sample is on the same device as the parameters and has same dtype + sample = randn_tensor( + self.mean.shape, + generator=generator, + device=self.parameters.device, + dtype=self.parameters.dtype, + ) + x = self.mean + self.std * sample + return x + + def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample: torch.Tensor, dims: tuple[int, ...] = [1, 2, 3]) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self) -> torch.Tensor: + return self.mean + + +class IdentityDistribution(object): + def __init__(self, parameters: torch.Tensor): + self.parameters = parameters + + def sample(self, generator: torch.Generator | None = None) -> torch.Tensor: + return self.parameters + + def mode(self) -> torch.Tensor: + return self.parameters + + +class EncoderTiny(nn.Module): + r""" + The `EncoderTiny` layer is a simpler version of the `Encoder` layer. + + Args: + in_channels (`int`): + The number of input channels. + out_channels (`int`): + The number of output channels. + num_blocks (`tuple[int, ...]`): + Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to + use. + block_out_channels (`tuple[int, ...]`): + The number of output channels for each block. + act_fn (`str`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_blocks: tuple[int, ...], + block_out_channels: tuple[int, ...], + act_fn: str, + ): + super().__init__() + + layers = [] + for i, num_block in enumerate(num_blocks): + num_channels = block_out_channels[i] + + if i == 0: + layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1)) + else: + layers.append( + nn.Conv2d( + num_channels, + num_channels, + kernel_size=3, + padding=1, + stride=2, + bias=False, + ) + ) + + for _ in range(num_block): + layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn)) + + layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1)) + + self.layers = nn.Sequential(*layers) + self.gradient_checkpointing = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `EncoderTiny` class.""" + if torch.is_grad_enabled() and self.gradient_checkpointing: + x = self._gradient_checkpointing_func(self.layers, x) + + else: + # scale image from [-1, 1] to [0, 1] to match TAESD convention + x = self.layers(x.add(1).div(2)) + + return x + + +class DecoderTiny(nn.Module): + r""" + The `DecoderTiny` layer is a simpler version of the `Decoder` layer. + + Args: + in_channels (`int`): + The number of input channels. + out_channels (`int`): + The number of output channels. + num_blocks (`tuple[int, ...]`): + Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to + use. + block_out_channels (`tuple[int, ...]`): + The number of output channels for each block. + upsampling_scaling_factor (`int`): + The scaling factor to use for upsampling. + act_fn (`str`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_blocks: tuple[int, ...], + block_out_channels: tuple[int, ...], + upsampling_scaling_factor: int, + act_fn: str, + upsample_fn: str, + ): + super().__init__() + + layers = [ + nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1), + get_activation(act_fn), + ] + + for i, num_block in enumerate(num_blocks): + is_final_block = i == (len(num_blocks) - 1) + num_channels = block_out_channels[i] + + for _ in range(num_block): + layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn)) + + if not is_final_block: + layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor, mode=upsample_fn)) + + conv_out_channel = num_channels if not is_final_block else out_channels + layers.append( + nn.Conv2d( + num_channels, + conv_out_channel, + kernel_size=3, + padding=1, + bias=is_final_block, + ) + ) + + self.layers = nn.Sequential(*layers) + self.gradient_checkpointing = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `DecoderTiny` class.""" + # Clamp. + x = torch.tanh(x / 3) * 3 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + x = self._gradient_checkpointing_func(self.layers, x) + else: + x = self.layers(x) + + # scale image from [0, 1] to [-1, 1] to match diffusers convention + return x.mul(2).sub(1) + + +class AutoencoderMixin: + def enable_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + if not hasattr(self, "use_tiling"): + raise NotImplementedError(f"Tiling doesn't seem to be implemented for {self.__class__.__name__}.") + self.use_tiling = True + + def disable_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + if not hasattr(self, "use_slicing"): + raise NotImplementedError(f"Slicing doesn't seem to be implemented for {self.__class__.__name__}.") + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/vq_model.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/vq_model.py new file mode 100644 index 0000000000000000000000000000000000000000..9214fb7faad6318f0fcb404a5bae8c2fe2229ed2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/autoencoders/vq_model.py @@ -0,0 +1,183 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ...utils.accelerate_utils import apply_forward_hook +from ..autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin + + +@dataclass +class VQEncoderOutput(BaseOutput): + """ + Output of VQModel encoding method. + + Args: + latents (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): + The encoded output sample from the last layer of the model. + """ + + latents: torch.Tensor + + +class VQModel(ModelMixin, AutoencoderMixin, ConfigMixin): + r""" + A VQ-VAE model for decoding latent representations. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + tuple of downsample block types. + up_block_types (`tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + tuple of upsample block types. + block_out_channels (`tuple[int]`, *optional*, defaults to `(64,)`): + tuple of block output channels. + layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. + norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for normalization layers. + vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE. + scaling_factor (`float`, *optional*, defaults to `0.18215`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper. + norm_type (`str`, *optional*, defaults to `"group"`): + Type of normalization layer to use. Can be one of `"group"` or `"spatial"`. + """ + + _skip_layerwise_casting_patterns = ["quantize"] + _supports_group_offloading = False + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: tuple[str, ...] = ("DownEncoderBlock2D",), + up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",), + block_out_channels: tuple[int, ...] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 3, + sample_size: int = 32, + num_vq_embeddings: int = 256, + norm_num_groups: int = 32, + vq_embed_dim: int | None = None, + scaling_factor: float = 0.18215, + norm_type: str = "group", # group, spatial + mid_block_add_attention=True, + lookup_from_codebook=False, + force_upcast=False, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=False, + mid_block_add_attention=mid_block_add_attention, + ) + + vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels + + self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1) + self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False) + self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + norm_type=norm_type, + mid_block_add_attention=mid_block_add_attention, + ) + + @apply_forward_hook + def encode(self, x: torch.Tensor, return_dict: bool = True) -> VQEncoderOutput: + h = self.encoder(x) + h = self.quant_conv(h) + + if not return_dict: + return (h,) + + return VQEncoderOutput(latents=h) + + @apply_forward_hook + def decode( + self, h: torch.Tensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None + ) -> DecoderOutput | torch.Tensor: + # also go through quantization layer + if not force_not_quantize: + quant, commit_loss, _ = self.quantize(h) + elif self.config.lookup_from_codebook: + quant = self.quantize.get_codebook_entry(h, shape) + commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype) + else: + quant = h + commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype) + quant2 = self.post_quant_conv(quant) + dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None) + + if not return_dict: + return dec, commit_loss + + return DecoderOutput(sample=dec, commit_loss=commit_loss) + + def forward(self, sample: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor, ...]: + r""" + The [`VQModel`] forward method. + + Args: + sample (`torch.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.autoencoders.vq_model.VQEncoderOutput`] instead of a plain tuple. + + Returns: + [`~models.autoencoders.vq_model.VQEncoderOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoders.vq_model.VQEncoderOutput`] is returned, otherwise a + plain `tuple` is returned. + """ + + h = self.encode(sample).latents + dec = self.decode(h) + + if not return_dict: + return dec.sample, dec.commit_loss + return dec diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17213d0b115ac61a1dc3e4064791201546238642 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80d9b98d5037221f5b437433aac2004b9f23cdf5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_cosmos.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_cosmos.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..247fdf561afa2de049cc5a52fc1d2f3ab2ef5b37 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_cosmos.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_flax.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_flax.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0ca16be3669baa66a33b261c8dda527d8e0e85f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_flax.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_flux.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_flux.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3eebcb7a904773b160cd1e0fadc766cf59a156a0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_flux.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_hunyuan.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_hunyuan.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..738278afd6ef1f85338696c59dee87d3ef31e06b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_hunyuan.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_qwenimage.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_qwenimage.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31ec240560f11c253c0599c1d74dc9f6c93d5dc0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_qwenimage.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_sana.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_sana.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e943d49fbfc1b8e14753ca7e35416e2caaafc35 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_sana.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_sd3.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_sd3.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fde05b919c74e9e35ccce910956405ecc778d232 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_sd3.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_sparsectrl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_sparsectrl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c3fa097d68f6ba07e045353ad17a80b0c31eef0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_sparsectrl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_union.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_union.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..864b5cd0ac0dd0f06075bf581e5d571fec258146 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_union.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_xs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_xs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..268d7af6037a9335f46abfcf21ad22aebfbf6296 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_xs.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_z_image.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_z_image.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1bfa074c5a7424079a1cb92f2533353a1e0776a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/controlnet_z_image.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/multicontrolnet.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/multicontrolnet.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..941aa23edb3adbd7b7e1a0021edb54b397e618b7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/multicontrolnet.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/multicontrolnet_union.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/multicontrolnet_union.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ec4730403339da1683cde226f2ee2ded83d35c7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/controlnets/__pycache__/multicontrolnet_union.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9ef04fb62606ed30699be8ac6cb38b011a692070 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__init__.py @@ -0,0 +1,18 @@ +from ...utils import is_flax_available, is_torch_available + + +if is_torch_available(): + from .unet_1d import UNet1DModel + from .unet_2d import UNet2DModel + from .unet_2d_condition import UNet2DConditionModel + from .unet_3d_condition import UNet3DConditionModel + from .unet_i2vgen_xl import I2VGenXLUNet + from .unet_kandinsky3 import Kandinsky3UNet + from .unet_motion_model import MotionAdapter, UNetMotionModel + from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel + from .unet_stable_cascade import StableCascadeUNet + from .uvit_2d import UVit2DModel + + +if is_flax_available(): + from .unet_2d_condition_flax import FlaxUNet2DConditionModel diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2d5212543fe659f606cc4c33c921ce2b1c50040 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_1d_blocks.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_1d_blocks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59f03b45d5aa52db6722f5e42152b0d7fcd4e931 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_1d_blocks.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_2d_condition.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_2d_condition.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b990f7e4c39400e648600a249e8c89de63fb5267 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_2d_condition.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_2d_condition_flax.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_2d_condition_flax.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91ce5712919cb521bff0b68779f35ddc0a71b0cb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_2d_condition_flax.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_3d_blocks.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_3d_blocks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30824df35be9f22fb9d90f388e8bd82bb2cfeb19 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_3d_blocks.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_3d_condition.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_3d_condition.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bee52763602659d7d13c16a8bcc3f2f90ef7a2a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_3d_condition.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_i2vgen_xl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_i2vgen_xl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0b30cb625b2ef4689e43284f33b8e21d6fd3b3a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_i2vgen_xl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_kandinsky3.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_kandinsky3.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c957dfef395585ec813392410fc22fe5ebfa3e9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_kandinsky3.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_motion_model.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_motion_model.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92b03aae800e3c1dba8da4acf137b6628fd94e04 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_motion_model.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_stable_cascade.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_stable_cascade.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..871f6b2f12f4f600381db609e15b8e2bbd22acb8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/unet_stable_cascade.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/uvit_2d.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/uvit_2d.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1674c9ead2852ca273443a12e8afc98f3abf5cce Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/__pycache__/uvit_2d.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_1d.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..83ffe1f6f8cb6f0dcaed1de2c7514da768b474b7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_1d.py @@ -0,0 +1,265 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block + + +@dataclass +class UNet1DOutput(BaseOutput): + """ + The output of [`UNet1DModel`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, sample_size)`): + The hidden states output from the last layer of the model. + """ + + sample: torch.Tensor + + +class UNet1DModel(ModelMixin, ConfigMixin): + r""" + A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime. + in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 2): Number of channels in the output. + extra_in_channels (`int`, *optional*, defaults to 0): + Number of additional channels to be added to the input of the first down block. Useful for cases where the + input data has more channels than what the model was initially designed for. + time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use. + freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip sin to cos for Fourier time embedding. + down_block_types (`tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`): + tuple of downsample block types. + up_block_types (`tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`): + tuple of upsample block types. + block_out_channels (`tuple[int]`, *optional*, defaults to `(32, 32, 64)`): + tuple of block output channels. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet. + out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet. + act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks. + norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization. + layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block. + downsample_each_block (`int`, *optional*, defaults to `False`): + Experimental feature for using a UNet without upsampling. + """ + + _skip_layerwise_casting_patterns = ["norm"] + + @register_to_config + def __init__( + self, + sample_size: int = 65536, + sample_rate: int | None = None, + in_channels: int = 2, + out_channels: int = 2, + extra_in_channels: int = 0, + time_embedding_type: str = "fourier", + time_embedding_dim: int | None = None, + flip_sin_to_cos: bool = True, + use_timestep_embedding: bool = False, + freq_shift: float = 0.0, + down_block_types: tuple[str, ...] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"), + up_block_types: tuple[str, ...] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"), + mid_block_type: str = "UNetMidBlock1D", + out_block_type: str = None, + block_out_channels: tuple[int, ...] = (32, 32, 64), + act_fn: str = None, + norm_num_groups: int = 8, + layers_per_block: int = 1, + downsample_each_block: bool = False, + ): + super().__init__() + self.sample_size = sample_size + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + embedding_size=time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + self.time_proj = Timesteps( + block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift + ) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + if use_timestep_embedding: + time_embed_dim = block_out_channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=timestep_input_dim, + time_embed_dim=time_embed_dim, + act_fn=act_fn, + out_dim=block_out_channels[0], + ) + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + self.out_block = None + + # down + output_channel = in_channels + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + + if i == 0: + input_channel += extra_in_channels + + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=block_out_channels[0], + add_downsample=not is_final_block or downsample_each_block, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = get_mid_block( + mid_block_type, + in_channels=block_out_channels[-1], + mid_channels=block_out_channels[-1], + out_channels=block_out_channels[-1], + embed_dim=block_out_channels[0], + num_layers=layers_per_block, + add_downsample=downsample_each_block, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + if out_block_type is None: + final_upsample_channels = out_channels + else: + final_upsample_channels = block_out_channels[0] + + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = ( + reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels + ) + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block, + in_channels=prev_output_channel, + out_channels=output_channel, + temb_channels=block_out_channels[0], + add_upsample=not is_final_block, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) + self.out_block = get_out_block( + out_block_type=out_block_type, + num_groups_out=num_groups_out, + embed_dim=block_out_channels[0], + out_channels=out_channels, + act_fn=act_fn, + fc_dim=block_out_channels[-1] // 4, + ) + + def forward( + self, + sample: torch.Tensor, + timestep: torch.Tensor | float | int, + return_dict: bool = True, + ) -> UNet1DOutput | tuple: + r""" + The [`UNet1DModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_1d.UNet1DOutput`] instead of a plain tuple. + + Returns: + [`~models.unets.unet_1d.UNet1DOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is + returned where the first element is the sample tensor. + """ + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + timestep_embed = self.time_proj(timesteps) + if self.config.use_timestep_embedding: + timestep_embed = self.time_mlp(timestep_embed.to(sample.dtype)) + else: + timestep_embed = timestep_embed[..., None] + timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype) + timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:])) + + # 2. down + down_block_res_samples = () + for downsample_block in self.down_blocks: + sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed) + down_block_res_samples += res_samples + + # 3. mid + if self.mid_block: + sample = self.mid_block(sample, timestep_embed) + + # 4. up + for i, upsample_block in enumerate(self.up_blocks): + res_samples = down_block_res_samples[-1:] + down_block_res_samples = down_block_res_samples[:-1] + sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed) + + # 5. post-process + if self.out_block: + sample = self.out_block(sample, timestep_embed) + + if not return_dict: + return (sample,) + + return UNet1DOutput(sample=sample) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_1d_blocks.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_1d_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..a21521eaed4c30d0f07942e892b1dd3345659813 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_1d_blocks.py @@ -0,0 +1,701 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import torch +import torch.nn.functional as F +from torch import nn + +from ..activations import get_activation +from ..resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims + + +class DownResnetBlock1D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + num_layers: int = 1, + conv_shortcut: bool = False, + temb_channels: int = 32, + groups: int = 32, + groups_out: int | None = None, + non_linearity: str | None = None, + time_embedding_norm: str = "default", + output_scale_factor: float = 1.0, + add_downsample: bool = True, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.add_downsample = add_downsample + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + # there will always be at least one resnet + resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)] + + for _ in range(num_layers): + resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) + + self.resnets = nn.ModuleList(resnets) + + if non_linearity is None: + self.nonlinearity = None + else: + self.nonlinearity = get_activation(non_linearity) + + self.downsample = None + if add_downsample: + self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) + + def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor | None = None) -> torch.Tensor: + output_states = () + + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.nonlinearity is not None: + hidden_states = self.nonlinearity(hidden_states) + + if self.downsample is not None: + hidden_states = self.downsample(hidden_states) + + return hidden_states, output_states + + +class UpResnetBlock1D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + num_layers: int = 1, + temb_channels: int = 32, + groups: int = 32, + groups_out: int | None = None, + non_linearity: str | None = None, + time_embedding_norm: str = "default", + output_scale_factor: float = 1.0, + add_upsample: bool = True, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.time_embedding_norm = time_embedding_norm + self.add_upsample = add_upsample + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + # there will always be at least one resnet + resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)] + + for _ in range(num_layers): + resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) + + self.resnets = nn.ModuleList(resnets) + + if non_linearity is None: + self.nonlinearity = None + else: + self.nonlinearity = get_activation(non_linearity) + + self.upsample = None + if add_upsample: + self.upsample = Upsample1D(out_channels, use_conv_transpose=True) + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...] | None = None, + temb: torch.Tensor | None = None, + ) -> torch.Tensor: + if res_hidden_states_tuple is not None: + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1) + + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + + if self.nonlinearity is not None: + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + hidden_states = self.upsample(hidden_states) + + return hidden_states + + +class ValueFunctionMidBlock1D(nn.Module): + def __init__(self, in_channels: int, out_channels: int, embed_dim: int): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.embed_dim = embed_dim + + self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim) + self.down1 = Downsample1D(out_channels // 2, use_conv=True) + self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim) + self.down2 = Downsample1D(out_channels // 4, use_conv=True) + + def forward(self, x: torch.Tensor, temb: torch.Tensor | None = None) -> torch.Tensor: + x = self.res1(x, temb) + x = self.down1(x) + x = self.res2(x, temb) + x = self.down2(x) + return x + + +class MidResTemporalBlock1D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + embed_dim: int, + num_layers: int = 1, + add_downsample: bool = False, + add_upsample: bool = False, + non_linearity: str | None = None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.add_downsample = add_downsample + + # there will always be at least one resnet + resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)] + + for _ in range(num_layers): + resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim)) + + self.resnets = nn.ModuleList(resnets) + + if non_linearity is None: + self.nonlinearity = None + else: + self.nonlinearity = get_activation(non_linearity) + + self.upsample = None + if add_upsample: + self.upsample = Upsample1D(out_channels, use_conv=True) + + self.downsample = None + if add_downsample: + self.downsample = Downsample1D(out_channels, use_conv=True) + + if self.upsample and self.downsample: + raise ValueError("Block cannot downsample and upsample") + + def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor: + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + + if self.upsample: + hidden_states = self.upsample(hidden_states) + if self.downsample: + hidden_states = self.downsample(hidden_states) + + return hidden_states + + +class OutConv1DBlock(nn.Module): + def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str): + super().__init__() + self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2) + self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim) + self.final_conv1d_act = get_activation(act_fn) + self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1) + + def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor | None = None) -> torch.Tensor: + hidden_states = self.final_conv1d_1(hidden_states) + hidden_states = rearrange_dims(hidden_states) + hidden_states = self.final_conv1d_gn(hidden_states) + hidden_states = rearrange_dims(hidden_states) + hidden_states = self.final_conv1d_act(hidden_states) + hidden_states = self.final_conv1d_2(hidden_states) + return hidden_states + + +class OutValueFunctionBlock(nn.Module): + def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"): + super().__init__() + self.final_block = nn.ModuleList( + [ + nn.Linear(fc_dim + embed_dim, fc_dim // 2), + get_activation(act_fn), + nn.Linear(fc_dim // 2, 1), + ] + ) + + def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor: + hidden_states = hidden_states.view(hidden_states.shape[0], -1) + hidden_states = torch.cat((hidden_states, temb), dim=-1) + for layer in self.final_block: + hidden_states = layer(hidden_states) + + return hidden_states + + +_kernels = { + "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8], + "cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875], + "lanczos3": [ + 0.003689131001010537, + 0.015056144446134567, + -0.03399861603975296, + -0.066637322306633, + 0.13550527393817902, + 0.44638532400131226, + 0.44638532400131226, + 0.13550527393817902, + -0.066637322306633, + -0.03399861603975296, + 0.015056144446134567, + 0.003689131001010537, + ], +} + + +class Downsample1d(nn.Module): + def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer("kernel", kernel_1d) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode) + weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) + indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) + kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1) + weight[indices, indices] = kernel + return F.conv1d(hidden_states, weight, stride=2) + + +class Upsample1d(nn.Module): + def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) * 2 + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer("kernel", kernel_1d) + + def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor | None = None) -> torch.Tensor: + hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) + weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) + indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) + kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1) + weight[indices, indices] = kernel + return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1) + + +class SelfAttention1d(nn.Module): + def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0): + super().__init__() + self.channels = in_channels + self.group_norm = nn.GroupNorm(1, num_channels=in_channels) + self.num_heads = n_head + + self.query = nn.Linear(self.channels, self.channels) + self.key = nn.Linear(self.channels, self.channels) + self.value = nn.Linear(self.channels, self.channels) + + self.proj_attn = nn.Linear(self.channels, self.channels, bias=True) + + self.dropout = nn.Dropout(dropout_rate, inplace=True) + + def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: + new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) + # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) + new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) + return new_projection + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + batch, channel_dim, seq = hidden_states.shape + + hidden_states = self.group_norm(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + + query_states = self.transpose_for_scores(query_proj) + key_states = self.transpose_for_scores(key_proj) + value_states = self.transpose_for_scores(value_proj) + + scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1])) + + attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) + attention_probs = torch.softmax(attention_scores, dim=-1) + + # compute attention output + hidden_states = torch.matmul(attention_probs, value_states) + + hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() + new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) + hidden_states = hidden_states.view(new_hidden_states_shape) + + # compute next hidden_states + hidden_states = self.proj_attn(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.dropout(hidden_states) + + output = hidden_states + residual + + return output + + +class ResConvBlock(nn.Module): + def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False): + super().__init__() + self.is_last = is_last + self.has_conv_skip = in_channels != out_channels + + if self.has_conv_skip: + self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False) + + self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2) + self.group_norm_1 = nn.GroupNorm(1, mid_channels) + self.gelu_1 = nn.GELU() + self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2) + + if not self.is_last: + self.group_norm_2 = nn.GroupNorm(1, out_channels) + self.gelu_2 = nn.GELU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states + + hidden_states = self.conv_1(hidden_states) + hidden_states = self.group_norm_1(hidden_states) + hidden_states = self.gelu_1(hidden_states) + hidden_states = self.conv_2(hidden_states) + + if not self.is_last: + hidden_states = self.group_norm_2(hidden_states) + hidden_states = self.gelu_2(hidden_states) + + output = hidden_states + residual + return output + + +class UNetMidBlock1D(nn.Module): + def __init__(self, mid_channels: int, in_channels: int, out_channels: int | None = None): + super().__init__() + + out_channels = in_channels if out_channels is None else out_channels + + # there is always at least one resnet + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + self.up = Upsample1d(kernel="cubic") + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor | None = None) -> torch.Tensor: + hidden_states = self.down(hidden_states) + for attn, resnet in zip(self.attentions, self.resnets): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + + hidden_states = self.up(hidden_states) + + return hidden_states + + +class AttnDownBlock1D(nn.Module): + def __init__(self, out_channels: int, in_channels: int, mid_channels: int | None = None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor | None = None) -> torch.Tensor: + hidden_states = self.down(hidden_states) + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + + return hidden_states, (hidden_states,) + + +class DownBlock1D(nn.Module): + def __init__(self, out_channels: int, in_channels: int, mid_channels: int | None = None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor | None = None) -> torch.Tensor: + hidden_states = self.down(hidden_states) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + return hidden_states, (hidden_states,) + + +class DownBlock1DNoSkip(nn.Module): + def __init__(self, out_channels: int, in_channels: int, mid_channels: int | None = None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor | None = None) -> torch.Tensor: + hidden_states = torch.cat([hidden_states, temb], dim=1) + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + return hidden_states, (hidden_states,) + + +class AttnUpBlock1D(nn.Module): + def __init__(self, in_channels: int, out_channels: int, mid_channels: int | None = None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.up = Upsample1d(kernel="cubic") + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + ) -> torch.Tensor: + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + + hidden_states = self.up(hidden_states) + + return hidden_states + + +class UpBlock1D(nn.Module): + def __init__(self, in_channels: int, out_channels: int, mid_channels: int | None = None): + super().__init__() + mid_channels = in_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + + self.resnets = nn.ModuleList(resnets) + self.up = Upsample1d(kernel="cubic") + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + ) -> torch.Tensor: + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + hidden_states = self.up(hidden_states) + + return hidden_states + + +class UpBlock1DNoSkip(nn.Module): + def __init__(self, in_channels: int, out_channels: int, mid_channels: int | None = None): + super().__init__() + mid_channels = in_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True), + ] + + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + ) -> torch.Tensor: + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + return hidden_states + + +DownBlockType = DownResnetBlock1D | DownBlock1D | AttnDownBlock1D | DownBlock1DNoSkip +MidBlockType = MidResTemporalBlock1D | ValueFunctionMidBlock1D | UNetMidBlock1D +OutBlockType = OutConv1DBlock | OutValueFunctionBlock +UpBlockType = UpResnetBlock1D | UpBlock1D | AttnUpBlock1D | UpBlock1DNoSkip + + +def get_down_block( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, +) -> DownBlockType: + if down_block_type == "DownResnetBlock1D": + return DownResnetBlock1D( + in_channels=in_channels, + num_layers=num_layers, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + ) + elif down_block_type == "DownBlock1D": + return DownBlock1D(out_channels=out_channels, in_channels=in_channels) + elif down_block_type == "AttnDownBlock1D": + return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels) + elif down_block_type == "DownBlock1DNoSkip": + return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool +) -> UpBlockType: + if up_block_type == "UpResnetBlock1D": + return UpResnetBlock1D( + in_channels=in_channels, + num_layers=num_layers, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + ) + elif up_block_type == "UpBlock1D": + return UpBlock1D(in_channels=in_channels, out_channels=out_channels) + elif up_block_type == "AttnUpBlock1D": + return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels) + elif up_block_type == "UpBlock1DNoSkip": + return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels) + raise ValueError(f"{up_block_type} does not exist.") + + +def get_mid_block( + mid_block_type: str, + num_layers: int, + in_channels: int, + mid_channels: int, + out_channels: int, + embed_dim: int, + add_downsample: bool, +) -> MidBlockType: + if mid_block_type == "MidResTemporalBlock1D": + return MidResTemporalBlock1D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + embed_dim=embed_dim, + add_downsample=add_downsample, + ) + elif mid_block_type == "ValueFunctionMidBlock1D": + return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim) + elif mid_block_type == "UNetMidBlock1D": + return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels) + raise ValueError(f"{mid_block_type} does not exist.") + + +def get_out_block( + *, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int +) -> OutBlockType | None: + if out_block_type == "OutConv1DBlock": + return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn) + elif out_block_type == "ValueFunction": + return OutValueFunctionBlock(fc_dim, embed_dim, act_fn) + return None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_2d.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..4e54f757d120a370dd996289b2b076bc93ddbb41 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_2d.py @@ -0,0 +1,353 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block + + +@dataclass +class UNet2DOutput(BaseOutput): + """ + The output of [`UNet2DModel`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output from the last layer of the model. + """ + + sample: torch.Tensor + + +class UNet2DModel(ModelMixin, ConfigMixin): + r""" + A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) - + 1)`. + in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use. + freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding. + flip_sin_to_cos (`bool`, *optional*, defaults to `True`): + Whether to flip sin to cos for Fourier time embedding. + down_block_types (`tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): + tuple of downsample block types. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`): + Block type for middle of UNet, it can be either `UNetMidBlock2D` or `None`. + up_block_types (`tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): + tuple of upsample block types. + block_out_channels (`tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`): + tuple of block output channels. + layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block. + mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block. + downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution. + downsample_type (`str`, *optional*, defaults to `conv`): + The downsample type for downsampling layers. Choose between "conv" and "resnet" + upsample_type (`str`, *optional*, defaults to `conv`): + The upsample type for upsampling layers. Choose between "conv" and "resnet" + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension. + norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization. + attn_norm_num_groups (`int`, *optional*, defaults to `None`): + If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the + given number of groups. If left as `None`, the group norm layer will only be created if + `resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups. + norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization. + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, or `"identity"`. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class + conditioning with `class_embed_type` equal to `None`. + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["norm"] + + @register_to_config + def __init__( + self, + sample_size: int | tuple[int, int] | None = None, + in_channels: int = 3, + out_channels: int = 3, + center_input_sample: bool = False, + time_embedding_type: str = "positional", + time_embedding_dim: int | None = None, + freq_shift: int = 0, + flip_sin_to_cos: bool = True, + down_block_types: tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), + mid_block_type: str | None = "UNetMidBlock2D", + up_block_types: tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), + block_out_channels: tuple[int, ...] = (224, 448, 672, 896), + layers_per_block: int = 2, + mid_block_scale_factor: float = 1, + downsample_padding: int = 1, + downsample_type: str = "conv", + upsample_type: str = "conv", + dropout: float = 0.0, + act_fn: str = "silu", + attention_head_dim: int | None = 8, + norm_num_groups: int = 32, + attn_norm_num_groups: int | None = None, + norm_eps: float = 1e-5, + resnet_time_scale_shift: str = "default", + add_attention: bool = True, + class_embed_type: str | None = None, + num_class_embeds: int | None = None, + num_train_timesteps: int | None = None, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + # input + self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + if time_embedding_type == "fourier": + self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16) + timestep_input_dim = 2 * block_out_channels[0] + elif time_embedding_type == "positional": + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + elif time_embedding_type == "learned": + self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0]) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + downsample_type=downsample_type, + dropout=dropout, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type is None: + self.mid_block = None + else: + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1], + resnet_groups=norm_num_groups, + attn_groups=attn_norm_num_groups, + add_attention=add_attention, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, + resnet_time_scale_shift=resnet_time_scale_shift, + upsample_type=upsample_type, + dropout=dropout, + ) + self.up_blocks.append(up_block) + + # out + num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def forward( + self, + sample: torch.Tensor, + timestep: torch.Tensor | float | int, + class_labels: torch.Tensor | None = None, + return_dict: bool = True, + ) -> UNet2DOutput | tuple: + r""" + The [`UNet2DModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d.UNet2DOutput`] instead of a plain tuple. + + Returns: + [`~models.unets.unet_2d.UNet2DOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is + returned where the first element is the sample tensor. + """ + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when doing class conditioning") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + elif self.class_embedding is None and class_labels is not None: + raise ValueError("class_embedding needs to be initialized in order to use class conditioning") + + # 2. pre-process + skip_sample = sample + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "skip_conv"): + sample, res_samples, skip_sample = downsample_block( + hidden_states=sample, temb=emb, skip_sample=skip_sample + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block(sample, emb) + + # 5. up + skip_sample = None + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + if hasattr(upsample_block, "skip_conv"): + sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample) + else: + sample = upsample_block(sample, res_samples, emb) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if skip_sample is not None: + sample += skip_sample + + if self.config.time_embedding_type == "fourier": + timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:])))) + sample = sample / timesteps + + if not return_dict: + return (sample,) + + return UNet2DOutput(sample=sample) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_2d_blocks.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_2d_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..18a0b4ec5659089e5f72eb7a6ee588ff49d18673 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_2d_blocks.py @@ -0,0 +1,3583 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from ...utils import deprecate, logging +from ...utils.torch_utils import apply_freeu +from ..activations import get_activation +from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 +from ..normalization import AdaGroupNorm +from ..resnet import ( + Downsample2D, + FirDownsample2D, + FirUpsample2D, + KDownsample2D, + KUpsample2D, + ResnetBlock2D, + ResnetBlockCondNorm2D, + Upsample2D, +) +from ..transformers.dual_transformer_2d import DualTransformer2DModel +from ..transformers.transformer_2d import Transformer2DModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_down_block( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, + resnet_eps: float, + resnet_act_fn: str, + transformer_layers_per_block: int = 1, + num_attention_heads: int | None = None, + resnet_groups: int | None = None, + cross_attention_dim: int | None = None, + downsample_padding: int | None = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: str | None = None, + attention_head_dim: int | None = None, + downsample_type: str | None = None, + dropout: float = 0.0, +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warning( + f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "ResnetDownsampleBlock2D": + return ResnetDownsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif down_block_type == "AttnDownBlock2D": + if add_downsample is False: + downsample_type = None + else: + downsample_type = downsample_type or "conv" # default to 'conv' + return AttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + downsample_type=downsample_type, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") + return CrossAttnDownBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + ) + elif down_block_type == "SimpleCrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D") + return SimpleCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif down_block_type == "SkipDownBlock2D": + return SkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnSkipDownBlock2D": + return AttnSkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "DownEncoderBlock2D": + return DownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnDownEncoderBlock2D": + return AttnDownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "KDownBlock2D": + return KDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif down_block_type == "KCrossAttnDownBlock2D": + return KCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + add_self_attention=True if not add_downsample else False, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_mid_block( + mid_block_type: str, + temb_channels: int, + in_channels: int, + resnet_eps: float, + resnet_act_fn: str, + resnet_groups: int, + output_scale_factor: float = 1.0, + transformer_layers_per_block: int = 1, + num_attention_heads: int | None = None, + cross_attention_dim: int | None = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + mid_block_only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + cross_attention_norm: str | None = None, + attention_head_dim: int | None = 1, + dropout: float = 0.0, +): + if mid_block_type == "UNetMidBlock2DCrossAttn": + return UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + temb_channels=temb_channels, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + output_scale_factor=output_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + resnet_groups=resnet_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + return UNetMidBlock2DSimpleCrossAttn( + in_channels=in_channels, + temb_channels=temb_channels, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + output_scale_factor=output_scale_factor, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type == "UNetMidBlock2D": + return UNetMidBlock2D( + in_channels=in_channels, + temb_channels=temb_channels, + dropout=dropout, + num_layers=0, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + output_scale_factor=output_scale_factor, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + add_attention=False, + ) + elif mid_block_type is None: + return None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + +def get_up_block( + up_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + add_upsample: bool, + resnet_eps: float, + resnet_act_fn: str, + resolution_idx: int | None = None, + transformer_layers_per_block: int = 1, + num_attention_heads: int | None = None, + resnet_groups: int | None = None, + cross_attention_dim: int | None = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: str | None = None, + attention_head_dim: int | None = None, + upsample_type: str | None = None, + dropout: float = 0.0, +) -> nn.Module: + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warning( + f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "ResnetUpsampleBlock2D": + return ResnetUpsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") + return CrossAttnUpBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + ) + elif up_block_type == "SimpleCrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D") + return SimpleCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif up_block_type == "AttnUpBlock2D": + if add_upsample is False: + upsample_type = None + else: + upsample_type = upsample_type or "conv" # default to 'conv' + + return AttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + upsample_type=upsample_type, + ) + elif up_block_type == "SkipUpBlock2D": + return SkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "AttnSkipUpBlock2D": + return AttnSkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "UpDecoderBlock2D": + return UpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "AttnUpDecoderBlock2D": + return AttnUpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "KUpBlock2D": + return KUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "KCrossAttnUpBlock2D": + return KCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + ) + + raise ValueError(f"{up_block_type} does not exist.") + + +class AutoencoderTinyBlock(nn.Module): + """ + Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU + blocks. + + Args: + in_channels (`int`): The number of input channels. + out_channels (`int`): The number of output channels. + act_fn (`str`): + ` The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`. + + Returns: + `torch.Tensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to + `out_channels`. + """ + + def __init__(self, in_channels: int, out_channels: int, act_fn: str): + super().__init__() + act_fn = get_activation(act_fn) + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + act_fn, + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + act_fn, + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + ) + self.skip = ( + nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) + if in_channels != out_channels + else nn.Identity() + ) + self.fuse = nn.ReLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.fuse(self.conv(x) + self.skip(x)) + + +class UNetMidBlock2D(nn.Module): + """ + A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks. + + Args: + in_channels (`int`): The number of input channels. + temb_channels (`int`): The number of temporal embedding channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_time_scale_shift (`str`, *optional*, defaults to `default`): + The type of normalization to apply to the time embeddings. This can help to improve the performance of the + model on tasks with long-range temporal dependencies. + resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + attn_groups (`int | None`, *optional*, defaults to None): The number of groups for the attention blocks. + resnet_pre_norm (`bool`, *optional*, defaults to `True`): + Whether to use pre-normalization for the resnet blocks. + add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks. + attention_head_dim (`int`, *optional*, defaults to 1): + Dimension of a single attention head. The number of attention heads is determined based on this value and + the number of input channels. + output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor. + + Returns: + `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels, + height, width)`. + + """ + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + attn_groups: int | None = None, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + if attn_groups is None: + attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None + + # there is always at least one resnet + if resnet_time_scale_shift == "spatial": + resnets = [ + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ] + else: + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}." + ) + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=attn_groups, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor | None = None) -> torch.Tensor: + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if torch.is_grad_enabled() and self.gradient_checkpointing: + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + out_channels: int | None = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int | tuple[int] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_groups_out: int | None = None, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + + out_channels = out_channels or in_channels + self.in_channels = in_channels + self.out_channels = out_channels + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + resnet_groups_out = resnet_groups_out or resnet_groups + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + groups_out=resnet_groups_out, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for i in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups_out, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups_out, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class UNetMidBlock2DSimpleCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: str | None = None, + ): + super().__init__() + + self.has_cross_attention = True + + self.attention_head_dim = attention_head_dim + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + self.num_heads = in_channels // self.attention_head_dim + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ] + attentions = [] + + for _ in range(num_layers): + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=in_channels, + cross_attention_dim=in_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + # attn + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + # resnet + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class AttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + downsample_type: str = "conv", + ): + super().__init__() + resnets = [] + attentions = [] + self.downsample_type = downsample_type + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if downsample_type == "conv": + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + elif downsample_type == "resnet": + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + down=True, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + upsample_size: int | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + hidden_states = attn(hidden_states, **cross_attention_kwargs) + output_states = output_states + (hidden_states,) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, **cross_attention_kwargs) + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + if self.downsample_type == "resnet": + hidden_states = downsampler(hidden_states, temb=temb) + else: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int | tuple[int] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + additional_residuals: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + output_states = () + + blocks = list(zip(self.resnets, self.attentions)) + + for i, (resnet, attn) in enumerate(blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class DownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor | None = None, *args, **kwargs + ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + output_states = () + + for resnet in self.resnets: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class DownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class AttnDownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + attentions = [] + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = attn(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class AttnSkipDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = np.sqrt(2.0), + add_downsample: bool = True, + ): + super().__init__() + self.attentions = nn.ModuleList([]) + self.resnets = nn.ModuleList([]) + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=32, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + skip_sample: torch.Tensor | None = None, + *args, + **kwargs, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...], torch.Tensor]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states += (hidden_states,) + + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + + hidden_states = self.skip_conv(skip_sample) + hidden_states + + output_states += (hidden_states,) + + return hidden_states, output_states, skip_sample + + +class SkipDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + output_scale_factor: float = np.sqrt(2.0), + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + skip_sample: torch.Tensor | None = None, + *args, + **kwargs, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...], torch.Tensor]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + + hidden_states = self.skip_conv(skip_sample) + hidden_states + + output_states += (hidden_states,) + + return hidden_states, output_states, skip_sample + + +class ResnetDownsampleBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + skip_time_act: bool = False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + down=True, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor | None = None, *args, **kwargs + ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + output_states = () + + for resnet in self.resnets: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class SimpleCrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: str | None = None, + ): + super().__init__() + + self.has_cross_attention = True + + resnets = [] + attentions = [] + + self.attention_head_dim = attention_head_dim + self.num_heads = out_channels // self.attention_head_dim + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + down=True, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + output_states = () + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + for resnet, attn in zip(self.resnets, self.attentions): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class KDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: int = 32, + add_downsample: bool = False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=groups, + groups_out=groups_out, + eps=resnet_eps, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + # YiYi's comments- might be able to use FirDownsample2D, look into details later + self.downsamplers = nn.ModuleList([KDownsample2D()]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor | None = None, *args, **kwargs + ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + output_states = () + + for resnet in self.resnets: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states, output_states + + +class KCrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + cross_attention_dim: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_group_size: int = 32, + add_downsample: bool = True, + attention_head_dim: int = 64, + add_self_attention: bool = False, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=groups, + groups_out=groups_out, + eps=resnet_eps, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + attentions.append( + KAttentionBlock( + out_channels, + out_channels // attention_head_dim, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + temb_channels=temb_channels, + attention_bias=True, + add_self_attention=add_self_attention, + cross_attention_norm="layer_norm", + group_size=resnet_group_size, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + if add_downsample: + self.downsamplers = nn.ModuleList([KDownsample2D()]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + resnet, + hidden_states, + temb, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if self.downsamplers is None: + output_states += (None,) + else: + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states, output_states + + +class AttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: int = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + upsample_type: str = "conv", + ): + super().__init__() + resnets = [] + attentions = [] + + self.upsample_type = upsample_type + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if upsample_type == "conv": + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + elif upsample_type == "resnet": + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + upsample_size: int | None = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + hidden_states = attn(hidden_states) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + if self.upsample_type == "resnet": + hidden_states = upsampler(hidden_states, temb=temb) + else: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: int | None = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int | tuple[int] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + upsample_size: int | None = None, + attention_mask: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: int | None = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + upsample_size: int | None = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + resolution_idx: int | None = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temb_channels: int | None = None, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.resolution_idx = resolution_idx + + def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor | None = None) -> torch.Tensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class AttnUpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + resolution_idx: int | None = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temb_channels: int | None = None, + ): + super().__init__() + resnets = [] + attentions = [] + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.resolution_idx = resolution_idx + + def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor | None = None) -> torch.Tensor: + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=temb) + hidden_states = attn(hidden_states, temb=temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class AttnSkipUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: int | None = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = np.sqrt(2.0), + add_upsample: bool = True, + ): + super().__init__() + self.attentions = nn.ModuleList([]) + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(resnet_in_channels + res_skip_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + self.attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=32, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + skip_sample=None, + *args, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + hidden_states = self.attentions[0](hidden_states) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + hidden_states = self.resnet_up(hidden_states, temb) + + return hidden_states, skip_sample + + +class SkipUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: int | None = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + output_scale_factor: float = np.sqrt(2.0), + add_upsample: bool = True, + upsample_padding: int = 1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min((resnet_in_channels + res_skip_channels) // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + skip_sample=None, + *args, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + hidden_states = self.resnet_up(hidden_states, temb) + + return hidden_states, skip_sample + + +class ResnetUpsampleBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: int | None = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + skip_time_act: bool = False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + upsample_size: int | None = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb) + + return hidden_states + + +class SimpleCrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: int | None = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: str | None = None, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attention_head_dim = attention_head_dim + + self.num_heads = out_channels // self.attention_head_dim + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + upsample_size: int | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + for resnet, attn in zip(self.resnets, self.attentions): + # resnet + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb) + + return hidden_states + + +class KUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + resolution_idx: int, + dropout: float = 0.0, + num_layers: int = 5, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: int | None = 32, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + k_in_channels = 2 * out_channels + k_out_channels = in_channels + num_layers = num_layers - 1 + + for i in range(num_layers): + in_channels = k_in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=k_out_channels if (i == num_layers - 1) else out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=groups, + groups_out=groups_out, + dropout=dropout, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([KUpsample2D()]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + upsample_size: int | None = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + res_hidden_states_tuple = res_hidden_states_tuple[-1] + if res_hidden_states_tuple is not None: + hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) + + for resnet in self.resnets: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class KCrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + resolution_idx: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: int = 32, + attention_head_dim: int = 1, # attention dim_head + cross_attention_dim: int = 768, + add_upsample: bool = True, + upcast_attention: bool = False, + ): + super().__init__() + resnets = [] + attentions = [] + + is_first_block = in_channels == out_channels == temb_channels + is_middle_block = in_channels != out_channels + add_self_attention = True if is_first_block else False + + self.has_cross_attention = True + self.attention_head_dim = attention_head_dim + + # in_channels, and out_channels for the block (k-unet) + k_in_channels = out_channels if is_first_block else 2 * out_channels + k_out_channels = in_channels + + num_layers = num_layers - 1 + + for i in range(num_layers): + in_channels = k_in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + if is_middle_block and (i == num_layers - 1): + conv_2d_out_channels = k_out_channels + else: + conv_2d_out_channels = None + + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + conv_2d_out_channels=conv_2d_out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=groups, + groups_out=groups_out, + dropout=dropout, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + attentions.append( + KAttentionBlock( + k_out_channels if (i == num_layers - 1) else out_channels, + k_out_channels // attention_head_dim + if (i == num_layers - 1) + else out_channels // attention_head_dim, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + temb_channels=temb_channels, + attention_bias=True, + add_self_attention=add_self_attention, + cross_attention_norm="layer_norm", + upcast_attention=upcast_attention, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList([KUpsample2D()]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + upsample_size: int | None = None, + attention_mask: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + res_hidden_states_tuple = res_hidden_states_tuple[-1] + if res_hidden_states_tuple is not None: + hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) + + for resnet, attn in zip(self.resnets, self.attentions): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + resnet, + hidden_states, + temb, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +# can potentially later be renamed to `No-feed-forward` attention +class KAttentionBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + attention_bias (`bool`, *optional*, defaults to `False`): + Configure if the attention layers should contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to upcast the attention computation to `float32`. + temb_channels (`int`, *optional*, defaults to 768): + The number of channels in the token embedding. + add_self_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to add self-attention to the block. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + group_size (`int`, *optional*, defaults to 32): + The number of groups to separate the channels into for group normalization. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout: float = 0.0, + cross_attention_dim: int | None = None, + attention_bias: bool = False, + upcast_attention: bool = False, + temb_channels: int = 768, # for ada_group_norm + add_self_attention: bool = False, + cross_attention_norm: str | None = None, + group_size: int = 32, + ): + super().__init__() + self.add_self_attention = add_self_attention + + # 1. Self-Attn + if add_self_attention: + self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + cross_attention_norm=None, + ) + + # 2. Cross-Attn + self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + cross_attention_norm=cross_attention_norm, + ) + + def _to_3d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor: + return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1) + + def _to_4d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor: + return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + # TODO: mark emb as non-optional (self.norm2 requires it). + # requires assessing impact of change to positional param interface. + emb: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + # 1. Self-Attention + if self.add_self_attention: + norm_hidden_states = self.norm1(hidden_states, emb) + + height, weight = norm_hidden_states.shape[2:] + norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + attn_output = self._to_4d(attn_output, height, weight) + + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention/None + norm_hidden_states = self.norm2(hidden_states, emb) + + height, weight = norm_hidden_states.shape[2:] + norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask, + **cross_attention_kwargs, + ) + attn_output = self._to_4d(attn_output, height, weight) + + hidden_states = attn_output + hidden_states + + return hidden_states diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_2d_blocks_flax.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_2d_blocks_flax.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6005afdc31b03075749477ae8545ca16143599 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_2d_blocks_flax.py @@ -0,0 +1,430 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import flax.linen as nn +import jax.numpy as jnp + +from ...utils import logging +from ..attention_flax import FlaxTransformer2DModel +from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D + + +logger = logging.get_logger(__name__) + + +class FlaxCrossAttnDownBlock2D(nn.Module): + r""" + Cross Attention 2D Downsizing block - original architecture from Unet transformers: + https://huggingface.co/papers/2103.06104 + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + num_attention_heads (:obj:`int`, *optional*, defaults to 1): + Number of attention heads of each spatial transformer block + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsampling layer before each final output + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://huggingface.co/papers/2112.05682 + split_head_dim (`bool`, *optional*, defaults to `False`): + Whether to split the head dimension into a new axis for the self-attention computation. In most cases, + enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL. + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + in_channels: int + out_channels: int + dropout: float = 0.0 + num_layers: int = 1 + num_attention_heads: int = 1 + add_downsample: bool = True + use_linear_projection: bool = False + only_cross_attention: bool = False + use_memory_efficient_attention: bool = False + split_head_dim: bool = False + dtype: jnp.dtype = jnp.float32 + transformer_layers_per_block: int = 1 + + def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + + resnets = [] + attentions = [] + + for i in range(self.num_layers): + in_channels = self.in_channels if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=in_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + attn_block = FlaxTransformer2DModel( + in_channels=self.out_channels, + n_heads=self.num_attention_heads, + d_head=self.out_channels // self.num_attention_heads, + depth=self.transformer_layers_per_block, + use_linear_projection=self.use_linear_projection, + only_cross_attention=self.only_cross_attention, + use_memory_efficient_attention=self.use_memory_efficient_attention, + split_head_dim=self.split_head_dim, + dtype=self.dtype, + ) + attentions.append(attn_block) + + self.resnets = resnets + self.attentions = attentions + + if self.add_downsample: + self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) + output_states += (hidden_states,) + + if self.add_downsample: + hidden_states = self.downsamplers_0(hidden_states) + output_states += (hidden_states,) + + return hidden_states, output_states + + +class FlaxDownBlock2D(nn.Module): + r""" + Flax 2D downsizing block + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsampling layer before each final output + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + in_channels: int + out_channels: int + dropout: float = 0.0 + num_layers: int = 1 + add_downsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + + resnets = [] + + for i in range(self.num_layers): + in_channels = self.in_channels if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=in_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + self.resnets = resnets + + if self.add_downsample: + self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, temb, deterministic=True): + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + output_states += (hidden_states,) + + if self.add_downsample: + hidden_states = self.downsamplers_0(hidden_states) + output_states += (hidden_states,) + + return hidden_states, output_states + + +class FlaxCrossAttnUpBlock2D(nn.Module): + r""" + Cross Attention 2D Upsampling block - original architecture from Unet transformers: + https://huggingface.co/papers/2103.06104 + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + num_attention_heads (:obj:`int`, *optional*, defaults to 1): + Number of attention heads of each spatial transformer block + add_upsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add upsampling layer before each final output + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://huggingface.co/papers/2112.05682 + split_head_dim (`bool`, *optional*, defaults to `False`): + Whether to split the head dimension into a new axis for the self-attention computation. In most cases, + enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL. + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + in_channels: int + out_channels: int + prev_output_channel: int + dropout: float = 0.0 + num_layers: int = 1 + num_attention_heads: int = 1 + add_upsample: bool = True + use_linear_projection: bool = False + only_cross_attention: bool = False + use_memory_efficient_attention: bool = False + split_head_dim: bool = False + dtype: jnp.dtype = jnp.float32 + transformer_layers_per_block: int = 1 + + def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + + resnets = [] + attentions = [] + + for i in range(self.num_layers): + res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels + resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + attn_block = FlaxTransformer2DModel( + in_channels=self.out_channels, + n_heads=self.num_attention_heads, + d_head=self.out_channels // self.num_attention_heads, + depth=self.transformer_layers_per_block, + use_linear_projection=self.use_linear_projection, + only_cross_attention=self.only_cross_attention, + use_memory_efficient_attention=self.use_memory_efficient_attention, + split_head_dim=self.split_head_dim, + dtype=self.dtype, + ) + attentions.append(attn_block) + + self.resnets = resnets + self.attentions = attentions + + if self.add_upsample: + self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) + + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) + + if self.add_upsample: + hidden_states = self.upsamplers_0(hidden_states) + + return hidden_states + + +class FlaxUpBlock2D(nn.Module): + r""" + Flax 2D upsampling block + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + prev_output_channel (:obj:`int`): + Output channels from the previous block + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsampling layer before each final output + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + in_channels: int + out_channels: int + prev_output_channel: int + dropout: float = 0.0 + num_layers: int = 1 + add_upsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + + resnets = [] + + for i in range(self.num_layers): + res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels + resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + self.resnets = resnets + + if self.add_upsample: + self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) + + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + + if self.add_upsample: + hidden_states = self.upsamplers_0(hidden_states) + + return hidden_states + + +class FlaxUNetMidBlock2DCrossAttn(nn.Module): + r""" + Cross Attention 2D Mid-level block - original architecture from Unet transformers: + https://huggingface.co/papers/2103.06104 + + Parameters: + in_channels (:obj:`int`): + Input channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + num_attention_heads (:obj:`int`, *optional*, defaults to 1): + Number of attention heads of each spatial transformer block + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + enable memory efficient attention https://huggingface.co/papers/2112.05682 + split_head_dim (`bool`, *optional*, defaults to `False`): + Whether to split the head dimension into a new axis for the self-attention computation. In most cases, + enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL. + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + in_channels: int + dropout: float = 0.0 + num_layers: int = 1 + num_attention_heads: int = 1 + use_linear_projection: bool = False + use_memory_efficient_attention: bool = False + split_head_dim: bool = False + dtype: jnp.dtype = jnp.float32 + transformer_layers_per_block: int = 1 + + def setup(self): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + + # there is always at least one resnet + resnets = [ + FlaxResnetBlock2D( + in_channels=self.in_channels, + out_channels=self.in_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + ] + + attentions = [] + + for _ in range(self.num_layers): + attn_block = FlaxTransformer2DModel( + in_channels=self.in_channels, + n_heads=self.num_attention_heads, + d_head=self.in_channels // self.num_attention_heads, + depth=self.transformer_layers_per_block, + use_linear_projection=self.use_linear_projection, + use_memory_efficient_attention=self.use_memory_efficient_attention, + split_head_dim=self.split_head_dim, + dtype=self.dtype, + ) + attentions.append(attn_block) + + res_block = FlaxResnetBlock2D( + in_channels=self.in_channels, + out_channels=self.in_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + self.resnets = resnets + self.attentions = attentions + + def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + + return hidden_states diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_2d_condition.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_2d_condition.py new file mode 100644 index 0000000000000000000000000000000000000000..deae2589947552dc8944cba40fdb8b6dbc152017 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_2d_condition.py @@ -0,0 +1,1237 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin +from ...loaders.single_file_model import FromOriginalModelMixin +from ...utils import ( + BaseOutput, + apply_lora_scale, + deprecate, + logging, +) +from ..activations import get_activation +from ..attention import AttentionMixin +from ..attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttnAddedKVProcessor, + AttnProcessor, + FusedAttnProcessor2_0, +) +from ..embeddings import ( + GaussianFourierProjection, + GLIGENTextBoundingboxProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from ..modeling_utils import ModelMixin +from .unet_2d_blocks import ( + get_down_block, + get_mid_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + The output of [`UNet2DConditionModel`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.Tensor = None + + +class UNet2DConditionModel( + ModelMixin, AttentionMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin +): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int`, `tuple[int]`, or `tuple[tuple]` , *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + reverse_transformer_layers_per_block : (`tuple[tuple]`, *optional*, defaults to None): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling + blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `tuple[tuple]` and for + [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"] + _skip_layerwise_casting_patterns = ["norm"] + _repeated_blocks = ["BasicTransformerBlock"] + + @register_to_config + def __init__( + self, + sample_size: int | tuple[int, int] | None = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: tuple[str, ...] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: str | None = "UNetMidBlock2DCrossAttn", + up_block_types: tuple[str, ...] = ( + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + ), + only_cross_attention: bool | tuple[bool] = False, + block_out_channels: tuple[int, ...] = (320, 640, 1280, 1280), + layers_per_block: int | tuple[int] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + dropout: float = 0.0, + act_fn: str = "silu", + norm_num_groups: int | None = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int | tuple[int] = 1280, + transformer_layers_per_block: int | tuple[int] | tuple[tuple] = 1, + reverse_transformer_layers_per_block: tuple[tuple[int]] | None = None, + encoder_hid_dim: int | None = None, + encoder_hid_dim_type: str | None = None, + attention_head_dim: int | tuple[int] = 8, + num_attention_heads: int | tuple[int] | None = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: str | None = None, + addition_embed_type: str | None = None, + addition_time_embed_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: int | None = None, + time_embedding_act_fn: str | None = None, + timestep_post_act: str | None = None, + time_cond_proj_dim: int | None = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: int | None = None, + attention_type: str = "default", + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: bool | None = None, + cross_attention_norm: str | None = None, + addition_embed_type_num_heads: int = 64, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + self._check_config( + down_block_types=down_block_types, + up_block_types=up_block_types, + only_cross_attention=only_cross_attention, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + cross_attention_dim=cross_attention_dim, + transformer_layers_per_block=transformer_layers_per_block, + reverse_transformer_layers_per_block=reverse_transformer_layers_per_block, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim, timestep_input_dim = self._set_time_proj( + time_embedding_type, + block_out_channels=block_out_channels, + flip_sin_to_cos=flip_sin_to_cos, + freq_shift=freq_shift, + time_embedding_dim=time_embedding_dim, + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + self._set_encoder_hid_proj( + encoder_hid_dim_type, + cross_attention_dim=cross_attention_dim, + encoder_hid_dim=encoder_hid_dim, + ) + + # class embedding + self._set_class_embedding( + class_embed_type, + act_fn=act_fn, + num_class_embeds=num_class_embeds, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + time_embed_dim=time_embed_dim, + timestep_input_dim=timestep_input_dim, + ) + + self._set_add_embedding( + addition_embed_type, + addition_embed_type_num_heads=addition_embed_type_num_heads, + addition_time_embed_dim=addition_time_embed_dim, + cross_attention_dim=cross_attention_dim, + encoder_hid_dim=encoder_hid_dim, + flip_sin_to_cos=flip_sin_to_cos, + freq_shift=freq_shift, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + time_embed_dim=time_embed_dim, + ) + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = get_mid_block( + mid_block_type, + temb_channels=blocks_time_embed_dim, + in_channels=block_out_channels[-1], + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + output_scale_factor=mid_block_scale_factor, + transformer_layers_per_block=transformer_layers_per_block[-1], + num_attention_heads=num_attention_heads[-1], + cross_attention_dim=cross_attention_dim[-1], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + mid_block_only_cross_attention=mid_block_only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[-1], + dropout=dropout, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = ( + list(reversed(transformer_layers_per_block)) + if reverse_transformer_layers_per_block is None + else reverse_transformer_layers_per_block + ) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resolution_idx=i, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.up_blocks.append(up_block) + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim) + + def _check_config( + self, + down_block_types: tuple[str, ...], + up_block_types: tuple[str, ...], + only_cross_attention: bool | tuple[bool], + block_out_channels: tuple[int, ...], + layers_per_block: int | tuple[int], + cross_attention_dim: int | tuple[int], + transformer_layers_per_block: int | tuple[int, tuple[tuple[int]]], + reverse_transformer_layers_per_block: bool, + attention_head_dim: int, + num_attention_heads: int | tuple[int] | None, + ): + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") + + def _set_time_proj( + self, + time_embedding_type: str, + block_out_channels: int, + flip_sin_to_cos: bool, + freq_shift: float, + time_embedding_dim: int, + ) -> tuple[int, int]: + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + return time_embed_dim, timestep_input_dim + + def _set_encoder_hid_proj( + self, + encoder_hid_dim_type: str | None, + cross_attention_dim: int | tuple[int], + encoder_hid_dim: int | None, + ): + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'." + ) + else: + self.encoder_hid_proj = None + + def _set_class_embedding( + self, + class_embed_type: str | None, + act_fn: str, + num_class_embeds: int | None, + projection_class_embeddings_input_dim: int | None, + time_embed_dim: int, + timestep_input_dim: int, + ): + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + def _set_add_embedding( + self, + addition_embed_type: str, + addition_embed_type_num_heads: int, + addition_time_embed_dim: int | None, + flip_sin_to_cos: bool, + freq_shift: float, + cross_attention_dim: int | None, + encoder_hid_dim: int | None, + projection_class_embeddings_input_dim: int | None, + time_embed_dim: int, + ): + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError( + f"`addition_embed_type`: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'." + ) + + def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int): + if attention_type in ["gated", "gated-text-image"]: + positive_len = 768 + if isinstance(cross_attention_dim, int): + positive_len = cross_attention_dim + elif isinstance(cross_attention_dim, (list, tuple)): + positive_len = cross_attention_dim[0] + + feature_type = "text-only" if attention_type == "gated" else "text-image" + self.position_net = GLIGENTextBoundingboxProjection( + positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type + ) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def set_attention_slice(self, slice_size: str | int | list[int] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: list[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497. + + The suffixes after the scaling factors represent the stage blocks where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + def disable_freeu(self): + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: + setattr(upsample_block, k, None) + + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + > [!WARNING] > This API is 🧪 experimental. + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedAttnProcessor2_0()) + + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + > [!WARNING] > This API is 🧪 experimental. + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def get_time_embed(self, sample: torch.Tensor, timestep: torch.Tensor | float | int) -> torch.Tensor | None: + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" + if isinstance(timestep, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + return t_emb + + def get_class_embed(self, sample: torch.Tensor, class_labels: torch.Tensor | None) -> torch.Tensor | None: + class_emb = None + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + return class_emb + + def get_aug_embed( + self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: dict[str, Any] + ) -> torch.Tensor | None: + aug_emb = None + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb = self.add_embedding(image_embs, hint) + return aug_emb + + def process_encoder_hidden_states( + self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: dict[str, Any] + ) -> torch.Tensor: + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None: + encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states) + + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds) + encoder_hidden_states = (encoder_hidden_states, image_embeds) + return encoder_hidden_states + + @apply_lora_scale("cross_attention_kwargs") + def forward( + self, + sample: torch.Tensor, + timestep: torch.Tensor | float | int, + encoder_hidden_states: torch.Tensor, + class_labels: torch.Tensor | None = None, + timestep_cond: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + added_cond_kwargs: dict[str, torch.Tensor] | None = None, + down_block_additional_residuals: tuple[torch.Tensor] | None = None, + mid_block_additional_residual: torch.Tensor | None = None, + down_intrablock_additional_residuals: tuple[torch.Tensor] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + return_dict: bool = True, + ) -> UNet2DConditionOutput | tuple: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + # Forward upsample size to force interpolation output size. + forward_upsample_size = True + break + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + t_emb = self.get_time_embed(sample=sample, timestep=timestep) + emb = self.time_embedding(t_emb, timestep_cond) + + class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) + if class_emb is not None: + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + aug_emb = self.get_aug_embed( + emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + if self.config.addition_embed_type == "image_hint": + aug_emb, hint = aug_emb + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + encoder_hidden_states = self.process_encoder_hidden_states( + encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + + # 2. pre-process + sample = self.conv_in(sample) + + # 2.5 GLIGEN position net + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + + # 3. down + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape + ): + sample += down_intrablock_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_2d_condition_flax.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_2d_condition_flax.py new file mode 100644 index 0000000000000000000000000000000000000000..a361026fc0eac44953fa3b9c793ebb9e664edf49 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_2d_condition_flax.py @@ -0,0 +1,461 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict + +from ...configuration_utils import ConfigMixin, flax_register_to_config +from ...utils import BaseOutput, logging +from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps +from ..modeling_flax_utils import FlaxModelMixin +from .unet_2d_blocks_flax import ( + FlaxCrossAttnDownBlock2D, + FlaxCrossAttnUpBlock2D, + FlaxDownBlock2D, + FlaxUNetMidBlock2DCrossAttn, + FlaxUpBlock2D, +) + + +logger = logging.get_logger(__name__) + + +@flax.struct.dataclass +class FlaxUNet2DConditionOutput(BaseOutput): + """ + The output of [`FlaxUNet2DConditionModel`]. + + Args: + sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: jnp.ndarray + + +@flax_register_to_config +class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it's generic methods + implemented for all models (such as downloading or saving). + + This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its + general usage and behavior. + + Inherent JAX features such as the following are supported: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + sample_size (`int`, *optional*): + The size of the input sample. + in_channels (`int`, *optional*, defaults to 4): + The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): + The number of channels in the output. + down_block_types (`tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`): + The tuple of downsample blocks to use. + up_block_types (`tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer + is skipped. + block_out_channels (`tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + attention_head_dim (`int` or `tuple[int]`, *optional*, defaults to 8): + The dimension of the attention heads. + num_attention_heads (`int` or `tuple[int]`, *optional*): + The number of attention heads. + cross_attention_dim (`int`, *optional*, defaults to 768): + The dimension of the cross attention features. + dropout (`float`, *optional*, defaults to 0): + Dropout probability for down, up and bottleneck blocks. + flip_sin_to_cos (`bool`, *optional*, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): + Enable memory efficient attention as described [here](https://huggingface.co/papers/2112.05682). + split_head_dim (`bool`, *optional*, defaults to `False`): + Whether to split the head dimension into a new axis for the self-attention computation. In most cases, + enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL. + """ + + sample_size: int = 32 + in_channels: int = 4 + out_channels: int = 4 + down_block_types: tuple[str, ...] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ) + up_block_types: tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D") + mid_block_type: str | None = "UNetMidBlock2DCrossAttn" + only_cross_attention: bool | tuple[bool] = False + block_out_channels: tuple[int, ...] = (320, 640, 1280, 1280) + layers_per_block: int = 2 + attention_head_dim: int | tuple[int, ...] = 8 + num_attention_heads: int | tuple[int, ...] | None = None + cross_attention_dim: int = 1280 + dropout: float = 0.0 + use_linear_projection: bool = False + dtype: jnp.dtype = jnp.float32 + flip_sin_to_cos: bool = True + freq_shift: int = 0 + use_memory_efficient_attention: bool = False + split_head_dim: bool = False + transformer_layers_per_block: int | tuple[int, ...] = 1 + addition_embed_type: str | None = None + addition_time_embed_dim: int | None = None + addition_embed_type_num_heads: int = 64 + projection_class_embeddings_input_dim: int | None = None + + def init_weights(self, rng: jax.Array) -> FrozenDict: + # init input tensors + sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) + sample = jnp.zeros(sample_shape, dtype=jnp.float32) + timesteps = jnp.ones((1,), dtype=jnp.int32) + encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + added_cond_kwargs = None + if self.addition_embed_type == "text_time": + # we retrieve the expected `text_embeds_dim` by first checking if the architecture is a refiner + # or non-refiner architecture and then by "reverse-computing" from `projection_class_embeddings_input_dim` + is_refiner = ( + 5 * self.config.addition_time_embed_dim + self.config.cross_attention_dim + == self.config.projection_class_embeddings_input_dim + ) + num_micro_conditions = 5 if is_refiner else 6 + + text_embeds_dim = self.config.projection_class_embeddings_input_dim - ( + num_micro_conditions * self.config.addition_time_embed_dim + ) + + time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim + time_ids_dims = time_ids_channels // self.addition_time_embed_dim + added_cond_kwargs = { + "text_embeds": jnp.zeros((1, text_embeds_dim), dtype=jnp.float32), + "time_ids": jnp.zeros((1, time_ids_dims), dtype=jnp.float32), + } + return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"] + + def setup(self) -> None: + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) + + block_out_channels = self.block_out_channels + time_embed_dim = block_out_channels[0] * 4 + + if self.num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = self.num_attention_heads or self.attention_head_dim + + # input + self.conv_in = nn.Conv( + block_out_channels[0], + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + # time + self.time_proj = FlaxTimesteps( + block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift + ) + self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) + + only_cross_attention = self.only_cross_attention + if isinstance(only_cross_attention, bool): + only_cross_attention = (only_cross_attention,) * len(self.down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(self.down_block_types) + + # transformer layers per block + transformer_layers_per_block = self.transformer_layers_per_block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(self.down_block_types) + + # addition embed types + if self.addition_embed_type is None: + self.add_embedding = None + elif self.addition_embed_type == "text_time": + if self.addition_time_embed_dim is None: + raise ValueError( + f"addition_embed_type {self.addition_embed_type} requires `addition_time_embed_dim` to not be None" + ) + self.add_time_proj = FlaxTimesteps(self.addition_time_embed_dim, self.flip_sin_to_cos, self.freq_shift) + self.add_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) + else: + raise ValueError(f"addition_embed_type: {self.addition_embed_type} must be None or `text_time`.") + + # down + down_blocks = [] + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(self.down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + if down_block_type == "CrossAttnDownBlock2D": + down_block = FlaxCrossAttnDownBlock2D( + in_channels=input_channel, + out_channels=output_channel, + dropout=self.dropout, + num_layers=self.layers_per_block, + transformer_layers_per_block=transformer_layers_per_block[i], + num_attention_heads=num_attention_heads[i], + add_downsample=not is_final_block, + use_linear_projection=self.use_linear_projection, + only_cross_attention=only_cross_attention[i], + use_memory_efficient_attention=self.use_memory_efficient_attention, + split_head_dim=self.split_head_dim, + dtype=self.dtype, + ) + else: + down_block = FlaxDownBlock2D( + in_channels=input_channel, + out_channels=output_channel, + dropout=self.dropout, + num_layers=self.layers_per_block, + add_downsample=not is_final_block, + dtype=self.dtype, + ) + + down_blocks.append(down_block) + self.down_blocks = down_blocks + + # mid + if self.config.mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = FlaxUNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + dropout=self.dropout, + num_attention_heads=num_attention_heads[-1], + transformer_layers_per_block=transformer_layers_per_block[-1], + use_linear_projection=self.use_linear_projection, + use_memory_efficient_attention=self.use_memory_efficient_attention, + split_head_dim=self.split_head_dim, + dtype=self.dtype, + ) + elif self.config.mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"Unexpected mid_block_type {self.config.mid_block_type}") + + # up + up_blocks = [] + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + for i, up_block_type in enumerate(self.up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + is_final_block = i == len(block_out_channels) - 1 + + if up_block_type == "CrossAttnUpBlock2D": + up_block = FlaxCrossAttnUpBlock2D( + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + num_layers=self.layers_per_block + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + num_attention_heads=reversed_num_attention_heads[i], + add_upsample=not is_final_block, + dropout=self.dropout, + use_linear_projection=self.use_linear_projection, + only_cross_attention=only_cross_attention[i], + use_memory_efficient_attention=self.use_memory_efficient_attention, + split_head_dim=self.split_head_dim, + dtype=self.dtype, + ) + else: + up_block = FlaxUpBlock2D( + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + num_layers=self.layers_per_block + 1, + add_upsample=not is_final_block, + dropout=self.dropout, + dtype=self.dtype, + ) + + up_blocks.append(up_block) + prev_output_channel = output_channel + self.up_blocks = up_blocks + + # out + self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5) + self.conv_out = nn.Conv( + self.out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + def __call__( + self, + sample: jnp.ndarray, + timesteps: jnp.ndarray | float | int, + encoder_hidden_states: jnp.ndarray, + added_cond_kwargs: dict | FrozenDict | None = None, + down_block_additional_residuals: tuple[jnp.ndarray, ...] | None = None, + mid_block_additional_residual: jnp.ndarray | None = None, + return_dict: bool = True, + train: bool = False, + ) -> FlaxUNet2DConditionOutput | tuple[jnp.ndarray]: + r""" + Args: + sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor + timestep (`jnp.ndarray` or `float` or `int`): timesteps + encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of + a plain tuple. + train (`bool`, *optional*, defaults to `False`): + Use deterministic functions and disable dropout when not training. + + Returns: + [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`: + [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. + """ + # 1. time + if not isinstance(timesteps, jnp.ndarray): + timesteps = jnp.array([timesteps], dtype=jnp.int32) + elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0: + timesteps = timesteps.astype(dtype=jnp.float32) + timesteps = jnp.expand_dims(timesteps, 0) + + t_emb = self.time_proj(timesteps) + t_emb = self.time_embedding(t_emb) + + # additional embeddings + aug_emb = None + if self.addition_embed_type == "text_time": + if added_cond_kwargs is None: + raise ValueError( + f"Need to provide argument `added_cond_kwargs` for {self.__class__} when using `addition_embed_type={self.addition_embed_type}`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if text_embeds is None: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + if time_ids is None: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + # compute time embeds + time_embeds = self.add_time_proj(jnp.ravel(time_ids)) # (1, 6) => (6,) => (6, 256) + time_embeds = jnp.reshape(time_embeds, (text_embeds.shape[0], -1)) + add_embeds = jnp.concatenate([text_embeds, time_embeds], axis=-1) + aug_emb = self.add_embedding(add_embeds) + + t_emb = t_emb + aug_emb if aug_emb is not None else t_emb + + # 2. pre-process + sample = jnp.transpose(sample, (0, 2, 3, 1)) + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for down_block in self.down_blocks: + if isinstance(down_block, FlaxCrossAttnDownBlock2D): + sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train) + else: + sample, res_samples = down_block(sample, t_emb, deterministic=not train) + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample += down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train) + + if mid_block_additional_residual is not None: + sample += mid_block_additional_residual + + # 5. up + for up_block in self.up_blocks: + res_samples = down_block_res_samples[-(self.layers_per_block + 1) :] + down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)] + if isinstance(up_block, FlaxCrossAttnUpBlock2D): + sample = up_block( + sample, + temb=t_emb, + encoder_hidden_states=encoder_hidden_states, + res_hidden_states_tuple=res_samples, + deterministic=not train, + ) + else: + sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples, deterministic=not train) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = nn.silu(sample) + sample = self.conv_out(sample) + sample = jnp.transpose(sample, (0, 3, 1, 2)) + + if not return_dict: + return (sample,) + + return FlaxUNet2DConditionOutput(sample=sample) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_3d_blocks.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_3d_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..b5f5aae342b47421c750d63ede26b6c21b7acff4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_3d_blocks.py @@ -0,0 +1,1419 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any + +import torch +from torch import nn + +from ...utils import deprecate, logging +from ...utils.torch_utils import apply_freeu +from ..attention import Attention +from ..resnet import ( + Downsample2D, + ResnetBlock2D, + SpatioTemporalResBlock, + TemporalConvLayer, + Upsample2D, +) +from ..transformers.transformer_2d import Transformer2DModel +from ..transformers.transformer_temporal import ( + TransformerSpatioTemporalModel, + TransformerTemporalModel, +) +from .unet_motion_model import ( + CrossAttnDownBlockMotion, + CrossAttnUpBlockMotion, + DownBlockMotion, + UNetMidBlockCrossAttnMotion, + UpBlockMotion, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class DownBlockMotion(DownBlockMotion): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `DownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import DownBlockMotion` instead." + deprecate("DownBlockMotion", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) + + +class CrossAttnDownBlockMotion(CrossAttnDownBlockMotion): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `CrossAttnDownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import CrossAttnDownBlockMotion` instead." + deprecate("CrossAttnDownBlockMotion", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) + + +class UpBlockMotion(UpBlockMotion): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `UpBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import UpBlockMotion` instead." + deprecate("UpBlockMotion", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) + + +class CrossAttnUpBlockMotion(CrossAttnUpBlockMotion): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `CrossAttnUpBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import CrossAttnUpBlockMotion` instead." + deprecate("CrossAttnUpBlockMotion", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) + + +class UNetMidBlockCrossAttnMotion(UNetMidBlockCrossAttnMotion): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `UNetMidBlockCrossAttnMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import UNetMidBlockCrossAttnMotion` instead." + deprecate("UNetMidBlockCrossAttnMotion", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) + + +def get_down_block( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, + resnet_eps: float, + resnet_act_fn: str, + num_attention_heads: int, + resnet_groups: int | None = None, + cross_attention_dim: int | None = None, + downsample_padding: int | None = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = True, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + transformer_layers_per_block: int | tuple[int] = 1, + temporal_transformer_layers_per_block: int | tuple[int] = 1, + dropout: float = 0.0, +) -> "DownBlock3D" | "CrossAttnDownBlock3D" | "DownBlockSpatioTemporal" | "CrossAttnDownBlockSpatioTemporal": + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + dropout=dropout, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + dropout=dropout, + ) + elif down_block_type == "DownBlockSpatioTemporal": + # added for SDV + return DownBlockSpatioTemporal( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + ) + elif down_block_type == "CrossAttnDownBlockSpatioTemporal": + # added for SDV + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal") + return CrossAttnDownBlockSpatioTemporal( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + add_downsample=add_downsample, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + ) + + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + add_upsample: bool, + resnet_eps: float, + resnet_act_fn: str, + num_attention_heads: int, + resolution_idx: int | None = None, + resnet_groups: int | None = None, + cross_attention_dim: int | None = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = True, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + temporal_num_attention_heads: int = 8, + temporal_cross_attention_dim: int | None = None, + temporal_max_seq_length: int = 32, + transformer_layers_per_block: int | tuple[int] = 1, + temporal_transformer_layers_per_block: int | tuple[int] = 1, + dropout: float = 0.0, +) -> "UpBlock3D" | "CrossAttnUpBlock3D" | "UpBlockSpatioTemporal" | "CrossAttnUpBlockSpatioTemporal": + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + resolution_idx=resolution_idx, + dropout=dropout, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resolution_idx=resolution_idx, + dropout=dropout, + ) + elif up_block_type == "UpBlockSpatioTemporal": + # added for SDV + return UpBlockSpatioTemporal( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + add_upsample=add_upsample, + ) + elif up_block_type == "CrossAttnUpBlockSpatioTemporal": + # added for SDV + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporal") + return CrossAttnUpBlockSpatioTemporal( + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + add_upsample=add_upsample, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + resolution_idx=resolution_idx, + ) + + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = True, + upcast_attention: bool = False, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + temp_convs = [ + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ] + attentions = [] + temp_attentions = [] + + for _ in range(num_layers): + attentions.append( + Transformer2DModel( + in_channels // num_attention_heads, + num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + in_channels // num_attention_heads, + num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + num_frames: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + ) -> torch.Tensor: + hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames) + for attn, temp_attn, resnet, temp_conv in zip( + self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:] + ): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = temp_attn( + hidden_states, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + ): + super().__init__() + resnets = [] + attentions = [] + temp_attentions = [] + temp_convs = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + attentions.append( + Transformer2DModel( + out_channels // num_attention_heads, + num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // num_attention_heads, + num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + num_frames: int = 1, + cross_attention_kwargs: dict[str, Any] = None, + ) -> torch.Tensor | tuple[torch.Tensor, ...]: + # TODO(Patrick, William) - attention mask is not used + output_states = () + + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = temp_attn( + hidden_states, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + temp_convs = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + num_frames: int = 1, + ) -> torch.Tensor | tuple[torch.Tensor, ...]: + output_states = () + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resolution_idx: int | None = None, + ): + super().__init__() + resnets = [] + temp_convs = [] + attentions = [] + temp_attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + attentions.append( + Transformer2DModel( + out_channels // num_attention_heads, + num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // num_attention_heads, + num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + upsample_size: int | None = None, + attention_mask: torch.Tensor | None = None, + num_frames: int = 1, + cross_attention_kwargs: dict[str, Any] = None, + ) -> torch.Tensor: + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + # TODO(Patrick, William) - attention mask is not used + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = temp_attn( + hidden_states, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + resolution_idx: int | None = None, + ): + super().__init__() + resnets = [] + temp_convs = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + upsample_size: int | None = None, + num_frames: int = 1, + ) -> torch.Tensor: + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class MidBlockTemporalDecoder(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + attention_head_dim: int = 512, + num_layers: int = 1, + upcast_attention: bool = False, + ): + super().__init__() + + resnets = [] + attentions = [] + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + resnets.append( + SpatioTemporalResBlock( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=1e-6, + temporal_eps=1e-5, + merge_factor=0.0, + merge_strategy="learned", + switch_spatial_to_temporal_mix=True, + ) + ) + + attentions.append( + Attention( + query_dim=in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + eps=1e-6, + upcast_attention=upcast_attention, + norm_num_groups=32, + bias=True, + residual_connection=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.Tensor, + image_only_indicator: torch.Tensor, + ): + hidden_states = self.resnets[0]( + hidden_states, + image_only_indicator=image_only_indicator, + ) + for resnet, attn in zip(self.resnets[1:], self.attentions): + hidden_states = attn(hidden_states) + hidden_states = resnet( + hidden_states, + image_only_indicator=image_only_indicator, + ) + + return hidden_states + + +class UpBlockTemporalDecoder(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + SpatioTemporalResBlock( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=1e-6, + temporal_eps=1e-5, + merge_factor=0.0, + merge_strategy="learned", + switch_spatial_to_temporal_mix=True, + ) + ) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward( + self, + hidden_states: torch.Tensor, + image_only_indicator: torch.Tensor, + ) -> torch.Tensor: + for resnet in self.resnets: + hidden_states = resnet( + hidden_states, + image_only_indicator=image_only_indicator, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class UNetMidBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + num_layers: int = 1, + transformer_layers_per_block: int | tuple[int] = 1, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + # there is always at least one resnet + resnets = [ + SpatioTemporalResBlock( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=1e-5, + ) + ] + attentions = [] + + for i in range(num_layers): + attentions.append( + TransformerSpatioTemporalModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + ) + ) + + resnets.append( + SpatioTemporalResBlock( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=1e-5, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + image_only_indicator: torch.Tensor | None = None, + ) -> torch.Tensor: + hidden_states = self.resnets[0]( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator) + + return hidden_states + + +class DownBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + num_layers: int = 1, + add_downsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + SpatioTemporalResBlock( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=1e-5, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + image_only_indicator: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + output_states = () + for resnet in self.resnets: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator) + else: + hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + num_layers: int = 1, + transformer_layers_per_block: int | tuple[int] = 1, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + add_downsample: bool = True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + SpatioTemporalResBlock( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=1e-6, + ) + ) + attentions.append( + TransformerSpatioTemporalModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=1, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + image_only_indicator: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + output_states = () + + blocks = list(zip(self.resnets, self.attentions)) + for resnet, attn in blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class UpBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: int | None = None, + num_layers: int = 1, + resnet_eps: float = 1e-6, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SpatioTemporalResBlock( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + image_only_indicator: torch.Tensor | None = None, + upsample_size: int | None = None, + ) -> torch.Tensor: + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator) + else: + hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class CrossAttnUpBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: int | None = None, + num_layers: int = 1, + transformer_layers_per_block: int | tuple[int] = 1, + resnet_eps: float = 1e-6, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SpatioTemporalResBlock( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + ) + ) + attentions.append( + TransformerSpatioTemporalModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + image_only_indicator: torch.Tensor | None = None, + upsample_size: int | None = None, + ) -> torch.Tensor: + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, image_only_indicator) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb, image_only_indicator=image_only_indicator) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_3d_condition.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_3d_condition.py new file mode 100644 index 0000000000000000000000000000000000000000..5006e48feb46777f380cbc3c759b8bb71b14d3f7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_3d_condition.py @@ -0,0 +1,675 @@ +# Copyright 2025 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. +# Copyright 2025 The ModelScope Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import UNet2DConditionLoadersMixin +from ...utils import BaseOutput, logging +from ..activations import get_activation +from ..attention import AttentionMixin +from ..attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttnAddedKVProcessor, + AttnProcessor, + FusedAttnProcessor2_0, +) +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..transformers.transformer_temporal import TransformerTemporalModel +from .unet_3d_blocks import ( + UNetMidBlock3DCrossAttn, + get_down_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + """ + The output of [`UNet3DConditionModel`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.Tensor + + +class UNet3DConditionModel(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + down_block_types (`tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`): + The tuple of downsample blocks to use. + up_block_types (`tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`): + The tuple of upsample blocks to use. + block_out_channels (`tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): The number of attention heads. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + """ + + _supports_gradient_checkpointing = False + _skip_layerwise_casting_patterns = ["norm", "time_embedding"] + + @register_to_config + def __init__( + self, + sample_size: int | None = None, + in_channels: int = 4, + out_channels: int = 4, + down_block_types: tuple[str, ...] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + up_block_types: tuple[str, ...] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + ), + block_out_channels: tuple[int, ...] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int | None = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1024, + attention_head_dim: int | tuple[int] = 64, + num_attention_heads: int | tuple[int] | None = None, + time_cond_proj_dim: int | None = None, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise NotImplementedError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_kernel = 3 + conv_out_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], True, 0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + cond_proj_dim=time_cond_proj_dim, + ) + + self.transformer_in = TransformerTemporalModel( + num_attention_heads=8, + attention_head_dim=attention_head_dim, + in_channels=block_out_channels[0], + num_layers=1, + norm_num_groups=norm_num_groups, + ) + + # class embedding + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=False, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=False, + resolution_idx=i, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + self.conv_act = get_activation("silu") + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size: str | int | list[int]) -> None: + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: list[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def enable_forward_chunking(self, chunk_size: int | None = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + def disable_forward_chunking(self): + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, None, 0) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu + def enable_freeu(self, s1, s2, b1, b2): + r"""Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497. + + The suffixes after the scaling factors represent the stage blocks where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: + setattr(upsample_block, k, None) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + > [!WARNING] > This API is 🧪 experimental. + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + > [!WARNING] > This API is 🧪 experimental. + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def forward( + self, + sample: torch.Tensor, + timestep: torch.Tensor | float | int, + encoder_hidden_states: torch.Tensor, + class_labels: torch.Tensor | None = None, + timestep_cond: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + down_block_additional_residuals: tuple[torch.Tensor] | None = None, + mid_block_additional_residual: torch.Tensor | None = None, + return_dict: bool = True, + ) -> UNet3DConditionOutput | tuple[torch.Tensor]: + r""" + The [`UNet3DConditionModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, num_channels, num_frames, height, width`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + + Returns: + [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" + if isinstance(timestep, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + num_frames = sample.shape[2] + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames) + encoder_hidden_states = encoder_hidden_states.repeat_interleave( + num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames + ) + + # 2. pre-process + sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) + sample = self.conv_in(sample) + + sample = self.transformer_in( + sample, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) + + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + num_frames=num_frames, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = self.conv_out(sample) + + # reshape to (batch, channel, framerate, width, height) + sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_i2vgen_xl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_i2vgen_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..5c3cfe91d5bd27fc81846fdc34913fbe5f88e939 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_i2vgen_xl.py @@ -0,0 +1,650 @@ +# Copyright 2025 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import UNet2DConditionLoadersMixin +from ...utils import logging +from ..activations import get_activation +from ..attention import Attention, AttentionMixin, FeedForward +from ..attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttnAddedKVProcessor, + AttnProcessor, + FusedAttnProcessor2_0, +) +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..transformers.transformer_temporal import TransformerTemporalModel +from .unet_3d_blocks import ( + UNetMidBlock3DCrossAttn, + get_down_block, + get_up_block, +) +from .unet_3d_condition import UNet3DConditionOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class I2VGenXLTransformerTemporalEncoder(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + activation_fn: str = "geglu", + upcast_attention: bool = False, + ff_inner_dim: int | None = None, + dropout: int = 0.0, + ): + super().__init__() + self.norm1 = nn.LayerNorm(dim, elementwise_affine=True, eps=1e-5) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=False, + upcast_attention=upcast_attention, + out_bias=True, + ) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=False, + inner_dim=ff_inner_dim, + bias=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + ff_output = self.ff(hidden_states) + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class I2VGenXLUNet(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + I2VGenXL UNet. It is a conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and + returns a sample-shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + down_block_types (`tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + up_block_types (`tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + block_out_channels (`tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 64): Attention head dim. + num_attention_heads (`int`, *optional*): The number of attention heads. + """ + + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + sample_size: int | None = None, + in_channels: int = 4, + out_channels: int = 4, + down_block_types: tuple[str, ...] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + up_block_types: tuple[str, ...] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + ), + block_out_channels: tuple[int, ...] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + norm_num_groups: int | None = 32, + cross_attention_dim: int = 1024, + attention_head_dim: int | tuple[int] = 64, + num_attention_heads: int | tuple[int] | None = None, + ): + super().__init__() + + # When we first integrated the UNet into the library, we didn't have `attention_head_dim`. As a consequence + # of that, we used `num_attention_heads` for arguments that actually denote attention head dimension. This + # is why we ignore `num_attention_heads` and calculate it from `attention_head_dims` below. + # This is still an incorrect way of calculating `num_attention_heads` but we need to stick to it + # without running proper deprecation cycles for the {down,mid,up} blocks which are a + # part of the public API. + num_attention_heads = attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + # input + self.conv_in = nn.Conv2d(in_channels + in_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.transformer_in = TransformerTemporalModel( + num_attention_heads=8, + attention_head_dim=num_attention_heads, + in_channels=block_out_channels[0], + num_layers=1, + norm_num_groups=norm_num_groups, + ) + + # image embedding + self.image_latents_proj_in = nn.Sequential( + nn.Conv2d(4, in_channels * 4, 3, padding=1), + nn.SiLU(), + nn.Conv2d(in_channels * 4, in_channels * 4, 3, stride=1, padding=1), + nn.SiLU(), + nn.Conv2d(in_channels * 4, in_channels, 3, stride=1, padding=1), + ) + self.image_latents_temporal_encoder = I2VGenXLTransformerTemporalEncoder( + dim=in_channels, + num_attention_heads=2, + ff_inner_dim=in_channels * 4, + attention_head_dim=in_channels, + activation_fn="gelu", + ) + self.image_latents_context_embedding = nn.Sequential( + nn.Conv2d(4, in_channels * 8, 3, padding=1), + nn.SiLU(), + nn.AdaptiveAvgPool2d((32, 32)), + nn.Conv2d(in_channels * 8, in_channels * 16, 3, stride=2, padding=1), + nn.SiLU(), + nn.Conv2d(in_channels * 16, cross_attention_dim, 3, stride=2, padding=1), + ) + + # other embeddings -- time, context, fps, etc. + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], True, 0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn="silu") + self.context_embedding = nn.Sequential( + nn.Linear(cross_attention_dim, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, cross_attention_dim * in_channels), + ) + self.fps_embedding = nn.Sequential( + nn.Linear(timestep_input_dim, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + + # blocks + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=1e-05, + resnet_act_fn="silu", + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[i], + downsample_padding=1, + dual_cross_attention=False, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=1e-05, + resnet_act_fn="silu", + output_scale_factor=1, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=1e-05, + resnet_act_fn="silu", + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=False, + resolution_idx=i, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-05) + self.conv_act = get_activation("silu") + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: int | None = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking + def disable_forward_chunking(self): + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, None, 0) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu + def enable_freeu(self, s1, s2, b1, b2): + r"""Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497. + + The suffixes after the scaling factors represent the stage blocks where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: + setattr(upsample_block, k, None) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + > [!WARNING] > This API is 🧪 experimental. + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + > [!WARNING] > This API is 🧪 experimental. + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def forward( + self, + sample: torch.Tensor, + timestep: torch.Tensor | float | int, + fps: torch.Tensor, + image_latents: torch.Tensor, + image_embeddings: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + timestep_cond: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + return_dict: bool = True, + ) -> UNet3DConditionOutput | tuple[torch.Tensor]: + r""" + The [`I2VGenXLUNet`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + fps (`torch.Tensor`): Frames per second for the video being generated. Used as a "micro-condition". + image_latents (`torch.Tensor`): Image encodings from the VAE. + image_embeddings (`torch.Tensor`): + Projection embeddings of the conditioning image computed with a vision encoder. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] instead of a plain + tuple. + + Returns: + [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_3d_condition.UNet3DConditionOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. + """ + batch_size, channels, num_frames, height, width = sample.shape + + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass `timesteps` as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" + if isinstance(timesteps, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + t_emb = self.time_embedding(t_emb, timestep_cond) + + # 2. FPS + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + fps = fps.expand(fps.shape[0]) + fps_emb = self.fps_embedding(self.time_proj(fps).to(dtype=self.dtype)) + + # 3. time + FPS embeddings. + emb = t_emb + fps_emb + emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames) + + # 4. context embeddings. + # The context embeddings consist of both text embeddings from the input prompt + # AND the image embeddings from the input image. For images, both VAE encodings + # and the CLIP image embeddings are incorporated. + # So the final `context_embeddings` becomes the query for cross-attention. + context_emb = sample.new_zeros(batch_size, 0, self.config.cross_attention_dim) + context_emb = torch.cat([context_emb, encoder_hidden_states], dim=1) + + image_latents_for_context_embds = image_latents[:, :, :1, :] + image_latents_context_embs = image_latents_for_context_embds.permute(0, 2, 1, 3, 4).reshape( + image_latents_for_context_embds.shape[0] * image_latents_for_context_embds.shape[2], + image_latents_for_context_embds.shape[1], + image_latents_for_context_embds.shape[3], + image_latents_for_context_embds.shape[4], + ) + image_latents_context_embs = self.image_latents_context_embedding(image_latents_context_embs) + + _batch_size, _channels, _height, _width = image_latents_context_embs.shape + image_latents_context_embs = image_latents_context_embs.permute(0, 2, 3, 1).reshape( + _batch_size, _height * _width, _channels + ) + context_emb = torch.cat([context_emb, image_latents_context_embs], dim=1) + + image_emb = self.context_embedding(image_embeddings) + image_emb = image_emb.view(-1, self.config.in_channels, self.config.cross_attention_dim) + context_emb = torch.cat([context_emb, image_emb], dim=1) + context_emb = context_emb.repeat_interleave(num_frames, dim=0, output_size=context_emb.shape[0] * num_frames) + + image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape( + image_latents.shape[0] * image_latents.shape[2], + image_latents.shape[1], + image_latents.shape[3], + image_latents.shape[4], + ) + image_latents = self.image_latents_proj_in(image_latents) + image_latents = ( + image_latents[None, :] + .reshape(batch_size, num_frames, channels, height, width) + .permute(0, 3, 4, 1, 2) + .reshape(batch_size * height * width, num_frames, channels) + ) + image_latents = self.image_latents_temporal_encoder(image_latents) + image_latents = image_latents.reshape(batch_size, height, width, num_frames, channels).permute(0, 4, 3, 1, 2) + + # 5. pre-process + sample = torch.cat([sample, image_latents], dim=1) + sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) + sample = self.conv_in(sample) + sample = self.transformer_in( + sample, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # 6. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=context_emb, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) + + down_block_res_samples += res_samples + + # 7. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=context_emb, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + # 8. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=context_emb, + upsample_size=upsample_size, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + num_frames=num_frames, + ) + + # 9. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = self.conv_out(sample) + + # reshape to (batch, channel, framerate, width, height) + sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_kandinsky3.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_kandinsky3.py new file mode 100644 index 0000000000000000000000000000000000000000..6fa68b42ee30a20bd769a31c88c9f2ea86799085 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_kandinsky3.py @@ -0,0 +1,472 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput, logging +from ..attention import AttentionMixin +from ..attention_processor import Attention, AttnProcessor +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class Kandinsky3UNetOutput(BaseOutput): + sample: torch.Tensor = None + + +class Kandinsky3EncoderProj(nn.Module): + def __init__(self, encoder_hid_dim, cross_attention_dim): + super().__init__() + self.projection_linear = nn.Linear(encoder_hid_dim, cross_attention_dim, bias=False) + self.projection_norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, x): + x = self.projection_linear(x) + x = self.projection_norm(x) + return x + + +class Kandinsky3UNet(ModelMixin, AttentionMixin, ConfigMixin): + @register_to_config + def __init__( + self, + in_channels: int = 4, + time_embedding_dim: int = 1536, + groups: int = 32, + attention_head_dim: int = 64, + layers_per_block: int | tuple[int] = 3, + block_out_channels: tuple[int, ...] = (384, 768, 1536, 3072), + cross_attention_dim: int | tuple[int] = 4096, + encoder_hid_dim: int = 4096, + ): + super().__init__() + + # TODO(Yiyi): Give better name and put into config for the following 4 parameters + expansion_ratio = 4 + compression_ratio = 2 + add_cross_attention = (False, True, True, True) + add_self_attention = (False, True, True, True) + + out_channels = in_channels + init_channels = block_out_channels[0] // 2 + self.time_proj = Timesteps(init_channels, flip_sin_to_cos=False, downscale_freq_shift=1) + + self.time_embedding = TimestepEmbedding( + init_channels, + time_embedding_dim, + ) + + self.add_time_condition = Kandinsky3AttentionPooling( + time_embedding_dim, cross_attention_dim, attention_head_dim + ) + + self.conv_in = nn.Conv2d(in_channels, init_channels, kernel_size=3, padding=1) + + self.encoder_hid_proj = Kandinsky3EncoderProj(encoder_hid_dim, cross_attention_dim) + + hidden_dims = [init_channels] + list(block_out_channels) + in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:])) + text_dims = [cross_attention_dim if is_exist else None for is_exist in add_cross_attention] + num_blocks = len(block_out_channels) * [layers_per_block] + layer_params = [num_blocks, text_dims, add_self_attention] + rev_layer_params = map(reversed, layer_params) + + cat_dims = [] + self.num_levels = len(in_out_dims) + self.down_blocks = nn.ModuleList([]) + for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate( + zip(in_out_dims, *layer_params) + ): + down_sample = level != (self.num_levels - 1) + cat_dims.append(out_dim if level != (self.num_levels - 1) else 0) + self.down_blocks.append( + Kandinsky3DownSampleBlock( + in_dim, + out_dim, + time_embedding_dim, + text_dim, + res_block_num, + groups, + attention_head_dim, + expansion_ratio, + compression_ratio, + down_sample, + self_attention, + ) + ) + + self.up_blocks = nn.ModuleList([]) + for level, ((out_dim, in_dim), res_block_num, text_dim, self_attention) in enumerate( + zip(reversed(in_out_dims), *rev_layer_params) + ): + up_sample = level != 0 + self.up_blocks.append( + Kandinsky3UpSampleBlock( + in_dim, + cat_dims.pop(), + out_dim, + time_embedding_dim, + text_dim, + res_block_num, + groups, + attention_head_dim, + expansion_ratio, + compression_ratio, + up_sample, + self_attention, + ) + ) + + self.conv_norm_out = nn.GroupNorm(groups, init_channels) + self.conv_act_out = nn.SiLU() + self.conv_out = nn.Conv2d(init_channels, out_channels, kernel_size=3, padding=1) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True): + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + if not torch.is_tensor(timestep): + dtype = torch.float32 if isinstance(timestep, float) else torch.int32 + timestep = torch.tensor([timestep], dtype=dtype, device=sample.device) + elif len(timestep.shape) == 0: + timestep = timestep[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = timestep.expand(sample.shape[0]) + time_embed_input = self.time_proj(timestep).to(sample.dtype) + time_embed = self.time_embedding(time_embed_input) + + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + + if encoder_hidden_states is not None: + time_embed = self.add_time_condition(time_embed, encoder_hidden_states, encoder_attention_mask) + + hidden_states = [] + sample = self.conv_in(sample) + for level, down_sample in enumerate(self.down_blocks): + sample = down_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask) + if level != self.num_levels - 1: + hidden_states.append(sample) + + for level, up_sample in enumerate(self.up_blocks): + if level != 0: + sample = torch.cat([sample, hidden_states.pop()], dim=1) + sample = up_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask) + + sample = self.conv_norm_out(sample) + sample = self.conv_act_out(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + return Kandinsky3UNetOutput(sample=sample) + + +class Kandinsky3UpSampleBlock(nn.Module): + def __init__( + self, + in_channels, + cat_dim, + out_channels, + time_embed_dim, + context_dim=None, + num_blocks=3, + groups=32, + head_dim=64, + expansion_ratio=4, + compression_ratio=2, + up_sample=True, + self_attention=True, + ): + super().__init__() + up_resolutions = [[None, True if up_sample else None, None, None]] + [[None] * 4] * (num_blocks - 1) + hidden_channels = ( + [(in_channels + cat_dim, in_channels)] + + [(in_channels, in_channels)] * (num_blocks - 2) + + [(in_channels, out_channels)] + ) + attentions = [] + resnets_in = [] + resnets_out = [] + + self.self_attention = self_attention + self.context_dim = context_dim + + if self_attention: + attentions.append( + Kandinsky3AttentionBlock(out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio) + ) + else: + attentions.append(nn.Identity()) + + for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions): + resnets_in.append( + Kandinsky3ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution) + ) + + if context_dim is not None: + attentions.append( + Kandinsky3AttentionBlock( + in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio + ) + ) + else: + attentions.append(nn.Identity()) + + resnets_out.append( + Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets_in = nn.ModuleList(resnets_in) + self.resnets_out = nn.ModuleList(resnets_out) + + def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None): + for attention, resnet_in, resnet_out in zip(self.attentions[1:], self.resnets_in, self.resnets_out): + x = resnet_in(x, time_embed) + if self.context_dim is not None: + x = attention(x, time_embed, context, context_mask, image_mask) + x = resnet_out(x, time_embed) + + if self.self_attention: + x = self.attentions[0](x, time_embed, image_mask=image_mask) + return x + + +class Kandinsky3DownSampleBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + time_embed_dim, + context_dim=None, + num_blocks=3, + groups=32, + head_dim=64, + expansion_ratio=4, + compression_ratio=2, + down_sample=True, + self_attention=True, + ): + super().__init__() + attentions = [] + resnets_in = [] + resnets_out = [] + + self.self_attention = self_attention + self.context_dim = context_dim + + if self_attention: + attentions.append( + Kandinsky3AttentionBlock(in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio) + ) + else: + attentions.append(nn.Identity()) + + up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, False if down_sample else None, None]] + hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1) + for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions): + resnets_in.append( + Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio) + ) + + if context_dim is not None: + attentions.append( + Kandinsky3AttentionBlock( + out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio + ) + ) + else: + attentions.append(nn.Identity()) + + resnets_out.append( + Kandinsky3ResNetBlock( + out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets_in = nn.ModuleList(resnets_in) + self.resnets_out = nn.ModuleList(resnets_out) + + def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None): + if self.self_attention: + x = self.attentions[0](x, time_embed, image_mask=image_mask) + + for attention, resnet_in, resnet_out in zip(self.attentions[1:], self.resnets_in, self.resnets_out): + x = resnet_in(x, time_embed) + if self.context_dim is not None: + x = attention(x, time_embed, context, context_mask, image_mask) + x = resnet_out(x, time_embed) + return x + + +class Kandinsky3ConditionalGroupNorm(nn.Module): + def __init__(self, groups, normalized_shape, context_dim): + super().__init__() + self.norm = nn.GroupNorm(groups, normalized_shape, affine=False) + self.context_mlp = nn.Sequential(nn.SiLU(), nn.Linear(context_dim, 2 * normalized_shape)) + self.context_mlp[1].weight.data.zero_() + self.context_mlp[1].bias.data.zero_() + + def forward(self, x, context): + context = self.context_mlp(context) + + for _ in range(len(x.shape[2:])): + context = context.unsqueeze(-1) + + scale, shift = context.chunk(2, dim=1) + x = self.norm(x) * (scale + 1.0) + shift + return x + + +class Kandinsky3Block(nn.Module): + def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None): + super().__init__() + self.group_norm = Kandinsky3ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim) + self.activation = nn.SiLU() + if up_resolution is not None and up_resolution: + self.up_sample = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2) + else: + self.up_sample = nn.Identity() + + padding = int(kernel_size > 1) + self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding) + + if up_resolution is not None and not up_resolution: + self.down_sample = nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2) + else: + self.down_sample = nn.Identity() + + def forward(self, x, time_embed): + x = self.group_norm(x, time_embed) + x = self.activation(x) + x = self.up_sample(x) + x = self.projection(x) + x = self.down_sample(x) + return x + + +class Kandinsky3ResNetBlock(nn.Module): + def __init__( + self, in_channels, out_channels, time_embed_dim, norm_groups=32, compression_ratio=2, up_resolutions=4 * [None] + ): + super().__init__() + kernel_sizes = [1, 3, 3, 1] + hidden_channel = max(in_channels, out_channels) // compression_ratio + hidden_channels = ( + [(in_channels, hidden_channel)] + [(hidden_channel, hidden_channel)] * 2 + [(hidden_channel, out_channels)] + ) + self.resnet_blocks = nn.ModuleList( + [ + Kandinsky3Block(in_channel, out_channel, time_embed_dim, kernel_size, norm_groups, up_resolution) + for (in_channel, out_channel), kernel_size, up_resolution in zip( + hidden_channels, kernel_sizes, up_resolutions + ) + ] + ) + self.shortcut_up_sample = ( + nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2) + if True in up_resolutions + else nn.Identity() + ) + self.shortcut_projection = ( + nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity() + ) + self.shortcut_down_sample = ( + nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2) + if False in up_resolutions + else nn.Identity() + ) + + def forward(self, x, time_embed): + out = x + for resnet_block in self.resnet_blocks: + out = resnet_block(out, time_embed) + + x = self.shortcut_up_sample(x) + x = self.shortcut_projection(x) + x = self.shortcut_down_sample(x) + x = x + out + return x + + +class Kandinsky3AttentionPooling(nn.Module): + def __init__(self, num_channels, context_dim, head_dim=64): + super().__init__() + self.attention = Attention( + context_dim, + context_dim, + dim_head=head_dim, + out_dim=num_channels, + out_bias=False, + ) + + def forward(self, x, context, context_mask=None): + context_mask = context_mask.to(dtype=context.dtype) + context = self.attention(context.mean(dim=1, keepdim=True), context, context_mask) + return x + context.squeeze(1) + + +class Kandinsky3AttentionBlock(nn.Module): + def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4): + super().__init__() + self.in_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim) + self.attention = Attention( + num_channels, + context_dim or num_channels, + dim_head=head_dim, + out_dim=num_channels, + out_bias=False, + ) + + hidden_channels = expansion_ratio * num_channels + self.out_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim) + self.feed_forward = nn.Sequential( + nn.Conv2d(num_channels, hidden_channels, kernel_size=1, bias=False), + nn.SiLU(), + nn.Conv2d(hidden_channels, num_channels, kernel_size=1, bias=False), + ) + + def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None): + height, width = x.shape[-2:] + out = self.in_norm(x, time_embed) + out = out.reshape(x.shape[0], -1, height * width).permute(0, 2, 1) + context = context if context is not None else out + if context_mask is not None: + context_mask = context_mask.to(dtype=context.dtype) + + out = self.attention(out, context, context_mask) + out = out.permute(0, 2, 1).unsqueeze(-1).reshape(out.shape[0], -1, height, width) + x = x + out + + out = self.out_norm(x, time_embed) + out = self.feed_forward(out) + x = x + out + return x diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_motion_model.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_motion_model.py new file mode 100644 index 0000000000000000000000000000000000000000..97452eff05aa761713f57342fb518798eb12b042 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_motion_model.py @@ -0,0 +1,2109 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin +from ...utils import BaseOutput, apply_lora_scale, deprecate, logging +from ...utils.torch_utils import apply_freeu +from ..attention import AttentionMixin, BasicTransformerBlock +from ..attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttnAddedKVProcessor, + AttnProcessor, + AttnProcessor2_0, + FusedAttnProcessor2_0, + IPAdapterAttnProcessor, + IPAdapterAttnProcessor2_0, +) +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..resnet import Downsample2D, ResnetBlock2D, Upsample2D +from ..transformers.dual_transformer_2d import DualTransformer2DModel +from ..transformers.transformer_2d import Transformer2DModel +from .unet_2d_blocks import UNetMidBlock2DCrossAttn +from .unet_2d_condition import UNet2DConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNetMotionOutput(BaseOutput): + """ + The output of [`UNetMotionOutput`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.Tensor + + +class AnimateDiffTransformer3D(nn.Module): + """ + A Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlock` attention should contain a bias parameter. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported + activation functions. + norm_elementwise_affine (`bool`, *optional*): + Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization. + double_self_attention (`bool`, *optional*): + Configure if each `TransformerBlock` should contain two self-attention layers. + positional_embeddings: (`str`, *optional*): + The type of positional embeddings to apply to the sequence input before passing use. + num_positional_embeddings: (`int`, *optional*): + The maximum length of the sequence over which to apply positional embeddings. + """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: int | None = None, + out_channels: int | None = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: int | None = None, + attention_bias: bool = False, + sample_size: int | None = None, + activation_fn: str = "geglu", + norm_elementwise_affine: bool = True, + double_self_attention: bool = True, + positional_embeddings: str | None = None, + num_positional_embeddings: int | None = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + self.in_channels = in_channels + + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + double_self_attention=double_self_attention, + norm_elementwise_affine=norm_elementwise_affine, + positional_embeddings=positional_embeddings, + num_positional_embeddings=num_positional_embeddings, + ) + for _ in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.LongTensor | None = None, + timestep: torch.LongTensor | None = None, + class_labels: torch.LongTensor | None = None, + num_frames: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + ) -> torch.Tensor: + """ + The [`AnimateDiffTransformer3D`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous): + Input hidden_states. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + num_frames (`int`, *optional*, defaults to 1): + The number of frames to be processed per batch. This is used to reshape the hidden states. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Returns: + torch.Tensor: + The output tensor. + """ + # 1. Input + batch_frames, channel, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + + residual = hidden_states + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) + + hidden_states = self.proj_in(input=hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + hidden_states = self.proj_out(input=hidden_states) + hidden_states = ( + hidden_states[None, None, :] + .reshape(batch_size, height, width, num_frames, channel) + .permute(0, 3, 4, 1, 2) + .contiguous() + ) + hidden_states = hidden_states.reshape(batch_frames, channel, height, width) + + output = hidden_states + residual + return output + + +class DownBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + temporal_num_attention_heads: int | tuple[int] = 1, + temporal_cross_attention_dim: int | None = None, + temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: int | tuple[int] = 1, + temporal_double_self_attention: bool = True, + ): + super().__init__() + resnets = [] + motion_modules = [] + + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"`temporal_transformer_layers_per_block` must be an integer or a tuple of integers of length {num_layers}" + ) + + # support for variable number of attention head per temporal layers + if isinstance(temporal_num_attention_heads, int): + temporal_num_attention_heads = (temporal_num_attention_heads,) * num_layers + elif len(temporal_num_attention_heads) != num_layers: + raise ValueError( + f"`temporal_num_attention_heads` must be an integer or a tuple of integers of length {num_layers}" + ) + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + motion_modules.append( + AnimateDiffTransformer3D( + num_attention_heads=temporal_num_attention_heads[i], + in_channels=out_channels, + num_layers=temporal_transformer_layers_per_block[i], + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads[i], + double_self_attention=temporal_double_self_attention, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + num_frames: int = 1, + *args, + **kwargs, + ) -> torch.Tensor | tuple[torch.Tensor, ...]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + output_states = () + + blocks = zip(self.resnets, self.motion_modules) + for resnet, motion_module in blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + hidden_states = resnet(input_tensor=hidden_states, temb=temb) + + hidden_states = motion_module(hidden_states, num_frames=num_frames) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states=hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int | tuple[int] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + temporal_cross_attention_dim: int | None = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: int | tuple[int] = 1, + temporal_double_self_attention: bool = True, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = (transformer_layers_per_block,) * num_layers + elif len(transformer_layers_per_block) != num_layers: + raise ValueError( + f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" + ) + + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" + ) + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + + motion_modules.append( + AnimateDiffTransformer3D( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + num_layers=temporal_transformer_layers_per_block[i], + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + double_self_attention=temporal_double_self_attention, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + num_frames: int = 1, + encoder_attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + additional_residuals: torch.Tensor | None = None, + ): + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + output_states = () + + blocks = list(zip(self.resnets, self.attentions, self.motion_modules)) + for i, (resnet, attn, motion_module) in enumerate(blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + hidden_states = resnet(input_tensor=hidden_states, temb=temb) + + hidden_states = attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + hidden_states = motion_module(hidden_states, num_frames=num_frames) + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states=hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: int | None = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int | tuple[int] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + temporal_cross_attention_dim: int | None = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: int | tuple[int] = 1, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = (transformer_layers_per_block,) * num_layers + elif len(transformer_layers_per_block) != num_layers: + raise ValueError( + f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(transformer_layers_per_block)}" + ) + + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(temporal_transformer_layers_per_block)}" + ) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + motion_modules.append( + AnimateDiffTransformer3D( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + num_layers=temporal_transformer_layers_per_block[i], + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + upsample_size: int | None = None, + attention_mask: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + num_frames: int = 1, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + blocks = zip(self.resnets, self.attentions, self.motion_modules) + for resnet, attn, motion_module in blocks: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + hidden_states = resnet(input_tensor=hidden_states, temb=temb) + + hidden_states = attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + hidden_states = motion_module(hidden_states, num_frames=num_frames) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size) + + return hidden_states + + +class UpBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: int | None = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temporal_cross_attention_dim: int | None = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: int | tuple[int] = 1, + ): + super().__init__() + resnets = [] + motion_modules = [] + + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}" + ) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + motion_modules.append( + AnimateDiffTransformer3D( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + num_layers=temporal_transformer_layers_per_block[i], + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + upsample_size=None, + num_frames: int = 1, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + blocks = zip(self.resnets, self.motion_modules) + + for resnet, motion_module in blocks: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + hidden_states = resnet(input_tensor=hidden_states, temb=temb) + + hidden_states = motion_module(hidden_states, num_frames=num_frames) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size) + + return hidden_states + + +class UNetMidBlockCrossAttnMotion(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int | tuple[int] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + temporal_num_attention_heads: int = 1, + temporal_cross_attention_dim: int | None = None, + temporal_max_seq_length: int = 32, + temporal_transformer_layers_per_block: int | tuple[int] = 1, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = (transformer_layers_per_block,) * num_layers + elif len(transformer_layers_per_block) != num_layers: + raise ValueError( + f"`transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}." + ) + + # support for variable transformer layers per temporal block + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers + elif len(temporal_transformer_layers_per_block) != num_layers: + raise ValueError( + f"`temporal_transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}." + ) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + motion_modules = [] + + for i in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + motion_modules.append( + AnimateDiffTransformer3D( + num_attention_heads=temporal_num_attention_heads, + attention_head_dim=in_channels // temporal_num_attention_heads, + in_channels=in_channels, + num_layers=temporal_transformer_layers_per_block[i], + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + activation_fn="geglu", + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + num_frames: int = 1, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + hidden_states = self.resnets[0](input_tensor=hidden_states, temb=temb) + + blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) + for attn, resnet, motion_module in blocks: + hidden_states = attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + motion_module, hidden_states, None, None, None, num_frames, None + ) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + hidden_states = motion_module(hidden_states, None, None, None, num_frames, None) + hidden_states = resnet(input_tensor=hidden_states, temb=temb) + + return hidden_states + + +class MotionModules(nn.Module): + def __init__( + self, + in_channels: int, + layers_per_block: int = 2, + transformer_layers_per_block: int | tuple[int] = 8, + num_attention_heads: int | tuple[int] = 8, + attention_bias: bool = False, + cross_attention_dim: int | None = None, + activation_fn: str = "geglu", + norm_num_groups: int = 32, + max_seq_length: int = 32, + ): + super().__init__() + self.motion_modules = nn.ModuleList([]) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = (transformer_layers_per_block,) * layers_per_block + elif len(transformer_layers_per_block) != layers_per_block: + raise ValueError( + f"The number of transformer layers per block must match the number of layers per block, " + f"got {layers_per_block} and {len(transformer_layers_per_block)}" + ) + + for i in range(layers_per_block): + self.motion_modules.append( + AnimateDiffTransformer3D( + in_channels=in_channels, + num_layers=transformer_layers_per_block[i], + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + num_attention_heads=num_attention_heads, + attention_head_dim=in_channels // num_attention_heads, + positional_embeddings="sinusoidal", + num_positional_embeddings=max_seq_length, + ) + ) + + +class MotionAdapter(ModelMixin, ConfigMixin, FromOriginalModelMixin): + @register_to_config + def __init__( + self, + block_out_channels: tuple[int, ...] = (320, 640, 1280, 1280), + motion_layers_per_block: int | tuple[int] = 2, + motion_transformer_layers_per_block: int | tuple[int] | tuple[tuple[int]] = 1, + motion_mid_block_layers_per_block: int = 1, + motion_transformer_layers_per_mid_block: int | tuple[int] = 1, + motion_num_attention_heads: int | tuple[int] = 8, + motion_norm_num_groups: int = 32, + motion_max_seq_length: int = 32, + use_motion_mid_block: bool = True, + conv_in_channels: int | None = None, + ): + """Container to store AnimateDiff Motion Modules + + Args: + block_out_channels (`tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each UNet block. + motion_layers_per_block (`int` or `tuple[int]`, *optional*, defaults to 2): + The number of motion layers per UNet block. + motion_transformer_layers_per_block (`int`, `tuple[int]`, or `tuple[tuple[int]]`, *optional*, defaults to 1): + The number of transformer layers to use in each motion layer in each block. + motion_mid_block_layers_per_block (`int`, *optional*, defaults to 1): + The number of motion layers in the middle UNet block. + motion_transformer_layers_per_mid_block (`int` or `tuple[int]`, *optional*, defaults to 1): + The number of transformer layers to use in each motion layer in the middle block. + motion_num_attention_heads (`int` or `tuple[int]`, *optional*, defaults to 8): + The number of heads to use in each attention layer of the motion module. + motion_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use in each group normalization layer of the motion module. + motion_max_seq_length (`int`, *optional*, defaults to 32): + The maximum sequence length to use in the motion module. + use_motion_mid_block (`bool`, *optional*, defaults to True): + Whether to use a motion module in the middle of the UNet. + """ + + super().__init__() + down_blocks = [] + up_blocks = [] + + if isinstance(motion_layers_per_block, int): + motion_layers_per_block = (motion_layers_per_block,) * len(block_out_channels) + elif len(motion_layers_per_block) != len(block_out_channels): + raise ValueError( + f"The number of motion layers per block must match the number of blocks, " + f"got {len(block_out_channels)} and {len(motion_layers_per_block)}" + ) + + if isinstance(motion_transformer_layers_per_block, int): + motion_transformer_layers_per_block = (motion_transformer_layers_per_block,) * len(block_out_channels) + + if isinstance(motion_transformer_layers_per_mid_block, int): + motion_transformer_layers_per_mid_block = ( + motion_transformer_layers_per_mid_block, + ) * motion_mid_block_layers_per_block + elif len(motion_transformer_layers_per_mid_block) != motion_mid_block_layers_per_block: + raise ValueError( + f"The number of layers per mid block ({motion_mid_block_layers_per_block}) " + f"must match the length of motion_transformer_layers_per_mid_block ({len(motion_transformer_layers_per_mid_block)})" + ) + + if isinstance(motion_num_attention_heads, int): + motion_num_attention_heads = (motion_num_attention_heads,) * len(block_out_channels) + elif len(motion_num_attention_heads) != len(block_out_channels): + raise ValueError( + f"The length of the attention head number tuple in the motion module must match the " + f"number of block, got {len(motion_num_attention_heads)} and {len(block_out_channels)}" + ) + + if conv_in_channels: + # input + self.conv_in = nn.Conv2d(conv_in_channels, block_out_channels[0], kernel_size=3, padding=1) + else: + self.conv_in = None + + for i, channel in enumerate(block_out_channels): + output_channel = block_out_channels[i] + down_blocks.append( + MotionModules( + in_channels=output_channel, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=None, + activation_fn="geglu", + attention_bias=False, + num_attention_heads=motion_num_attention_heads[i], + max_seq_length=motion_max_seq_length, + layers_per_block=motion_layers_per_block[i], + transformer_layers_per_block=motion_transformer_layers_per_block[i], + ) + ) + + if use_motion_mid_block: + self.mid_block = MotionModules( + in_channels=block_out_channels[-1], + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=None, + activation_fn="geglu", + attention_bias=False, + num_attention_heads=motion_num_attention_heads[-1], + max_seq_length=motion_max_seq_length, + layers_per_block=motion_mid_block_layers_per_block, + transformer_layers_per_block=motion_transformer_layers_per_mid_block, + ) + else: + self.mid_block = None + + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + + reversed_motion_layers_per_block = list(reversed(motion_layers_per_block)) + reversed_motion_transformer_layers_per_block = list(reversed(motion_transformer_layers_per_block)) + reversed_motion_num_attention_heads = list(reversed(motion_num_attention_heads)) + for i, channel in enumerate(reversed_block_out_channels): + output_channel = reversed_block_out_channels[i] + up_blocks.append( + MotionModules( + in_channels=output_channel, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=None, + activation_fn="geglu", + attention_bias=False, + num_attention_heads=reversed_motion_num_attention_heads[i], + max_seq_length=motion_max_seq_length, + layers_per_block=reversed_motion_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_motion_transformer_layers_per_block[i], + ) + ) + + self.down_blocks = nn.ModuleList(down_blocks) + self.up_blocks = nn.ModuleList(up_blocks) + + def forward(self, sample): + pass + + +class UNetMotionModel(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): + r""" + A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a + sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["norm"] + + @register_to_config + def __init__( + self, + sample_size: int | None = None, + in_channels: int = 4, + out_channels: int = 4, + down_block_types: tuple[str, ...] = ( + "CrossAttnDownBlockMotion", + "CrossAttnDownBlockMotion", + "CrossAttnDownBlockMotion", + "DownBlockMotion", + ), + up_block_types: tuple[str, ...] = ( + "UpBlockMotion", + "CrossAttnUpBlockMotion", + "CrossAttnUpBlockMotion", + "CrossAttnUpBlockMotion", + ), + block_out_channels: tuple[int, ...] = (320, 640, 1280, 1280), + layers_per_block: int | tuple[int] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + transformer_layers_per_block: int | tuple[int] | tuple[tuple] = 1, + reverse_transformer_layers_per_block: int | tuple[int] | tuple[tuple] | None = None, + temporal_transformer_layers_per_block: int | tuple[int] | tuple[tuple] = 1, + reverse_temporal_transformer_layers_per_block: int | tuple[int] | tuple[tuple] | None = None, + transformer_layers_per_mid_block: int | tuple[int] | None = None, + temporal_transformer_layers_per_mid_block: int | tuple[int] | None = 1, + use_linear_projection: bool = False, + num_attention_heads: int | tuple[int, ...] = 8, + motion_max_seq_length: int = 32, + motion_num_attention_heads: int | tuple[int, ...] = 8, + reverse_motion_num_attention_heads: int | tuple[int, ...] | tuple[tuple[int, ...], ...] | None = None, + use_motion_mid_block: bool = True, + mid_block_layers: int = 1, + encoder_hid_dim: int | None = None, + encoder_hid_dim_type: str | None = None, + addition_embed_type: str | None = None, + addition_time_embed_dim: int | None = None, + projection_class_embeddings_input_dim: int | None = None, + time_cond_proj_dim: int | None = None, + ): + super().__init__() + + self.sample_size = sample_size + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") + + if ( + isinstance(temporal_transformer_layers_per_block, list) + and reverse_temporal_transformer_layers_per_block is None + ): + for layer_number_per_block in temporal_transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError( + "Must provide 'reverse_temporal_transformer_layers_per_block` if using asymmetrical motion module in UNet." + ) + + # input + conv_in_kernel = 3 + conv_out_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], True, 0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, time_embed_dim, act_fn=act_fn, cond_proj_dim=time_cond_proj_dim + ) + + if encoder_hid_dim_type is None: + self.encoder_hid_proj = None + + if addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, True, 0) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + # class embedding + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if isinstance(reverse_transformer_layers_per_block, int): + reverse_transformer_layers_per_block = [reverse_transformer_layers_per_block] * len(down_block_types) + + if isinstance(temporal_transformer_layers_per_block, int): + temporal_transformer_layers_per_block = [temporal_transformer_layers_per_block] * len(down_block_types) + + if isinstance(reverse_temporal_transformer_layers_per_block, int): + reverse_temporal_transformer_layers_per_block = [reverse_temporal_transformer_layers_per_block] * len( + down_block_types + ) + + if isinstance(motion_num_attention_heads, int): + motion_num_attention_heads = (motion_num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + if down_block_type == "CrossAttnDownBlockMotion": + down_block = CrossAttnDownBlockMotion( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + num_attention_heads=num_attention_heads[i], + cross_attention_dim=cross_attention_dim[i], + downsample_padding=downsample_padding, + add_downsample=not is_final_block, + use_linear_projection=use_linear_projection, + temporal_num_attention_heads=motion_num_attention_heads[i], + temporal_max_seq_length=motion_max_seq_length, + temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], + ) + elif down_block_type == "DownBlockMotion": + down_block = DownBlockMotion( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + num_layers=layers_per_block[i], + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + add_downsample=not is_final_block, + downsample_padding=downsample_padding, + temporal_num_attention_heads=motion_num_attention_heads[i], + temporal_max_seq_length=motion_max_seq_length, + temporal_transformer_layers_per_block=temporal_transformer_layers_per_block[i], + ) + else: + raise ValueError( + "Invalid `down_block_type` encountered. Must be one of `CrossAttnDownBlockMotion` or `DownBlockMotion`" + ) + + self.down_blocks.append(down_block) + + # mid + if transformer_layers_per_mid_block is None: + transformer_layers_per_mid_block = ( + transformer_layers_per_block[-1] if isinstance(transformer_layers_per_block[-1], int) else 1 + ) + + if use_motion_mid_block: + self.mid_block = UNetMidBlockCrossAttnMotion( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + use_linear_projection=use_linear_projection, + num_layers=mid_block_layers, + temporal_num_attention_heads=motion_num_attention_heads[-1], + temporal_max_seq_length=motion_max_seq_length, + transformer_layers_per_block=transformer_layers_per_mid_block, + temporal_transformer_layers_per_block=temporal_transformer_layers_per_mid_block, + ) + + else: + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + use_linear_projection=use_linear_projection, + num_layers=mid_block_layers, + transformer_layers_per_block=transformer_layers_per_mid_block, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_motion_num_attention_heads = list(reversed(motion_num_attention_heads)) + + if reverse_transformer_layers_per_block is None: + reverse_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + + if reverse_temporal_transformer_layers_per_block is None: + reverse_temporal_transformer_layers_per_block = list(reversed(temporal_transformer_layers_per_block)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + if up_block_type == "CrossAttnUpBlockMotion": + up_block = CrossAttnUpBlockMotion( + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + resolution_idx=i, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reverse_transformer_layers_per_block[i], + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + num_attention_heads=reversed_num_attention_heads[i], + cross_attention_dim=reversed_cross_attention_dim[i], + add_upsample=add_upsample, + use_linear_projection=use_linear_projection, + temporal_num_attention_heads=reversed_motion_num_attention_heads[i], + temporal_max_seq_length=motion_max_seq_length, + temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i], + ) + elif up_block_type == "UpBlockMotion": + up_block = UpBlockMotion( + in_channels=input_channel, + prev_output_channel=prev_output_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + resolution_idx=i, + num_layers=reversed_layers_per_block[i] + 1, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + add_upsample=add_upsample, + temporal_num_attention_heads=reversed_motion_num_attention_heads[i], + temporal_max_seq_length=motion_max_seq_length, + temporal_transformer_layers_per_block=reverse_temporal_transformer_layers_per_block[i], + ) + else: + raise ValueError( + "Invalid `up_block_type` encountered. Must be one of `CrossAttnUpBlockMotion` or `UpBlockMotion`" + ) + + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + self.conv_act = nn.SiLU() + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + @classmethod + def from_unet2d( + cls, + unet: UNet2DConditionModel, + motion_adapter: MotionAdapter | None = None, + load_weights: bool = True, + ): + has_motion_adapter = motion_adapter is not None + + if has_motion_adapter: + motion_adapter.to(device=unet.device) + + # check compatibility of number of blocks + if len(unet.config["down_block_types"]) != len(motion_adapter.config["block_out_channels"]): + raise ValueError("Incompatible Motion Adapter, got different number of blocks") + + # check layers compatibility for each block + if isinstance(unet.config["layers_per_block"], int): + expanded_layers_per_block = [unet.config["layers_per_block"]] * len(unet.config["down_block_types"]) + else: + expanded_layers_per_block = list(unet.config["layers_per_block"]) + if isinstance(motion_adapter.config["motion_layers_per_block"], int): + expanded_adapter_layers_per_block = [motion_adapter.config["motion_layers_per_block"]] * len( + motion_adapter.config["block_out_channels"] + ) + else: + expanded_adapter_layers_per_block = list(motion_adapter.config["motion_layers_per_block"]) + if expanded_layers_per_block != expanded_adapter_layers_per_block: + raise ValueError("Incompatible Motion Adapter, got different number of layers per block") + + # based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459 + config = dict(unet.config) + config["_class_name"] = cls.__name__ + + down_blocks = [] + for down_blocks_type in config["down_block_types"]: + if "CrossAttn" in down_blocks_type: + down_blocks.append("CrossAttnDownBlockMotion") + else: + down_blocks.append("DownBlockMotion") + config["down_block_types"] = down_blocks + + up_blocks = [] + for down_blocks_type in config["up_block_types"]: + if "CrossAttn" in down_blocks_type: + up_blocks.append("CrossAttnUpBlockMotion") + else: + up_blocks.append("UpBlockMotion") + config["up_block_types"] = up_blocks + + if has_motion_adapter: + config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"] + config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"] + config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"] + config["layers_per_block"] = motion_adapter.config["motion_layers_per_block"] + config["temporal_transformer_layers_per_mid_block"] = motion_adapter.config[ + "motion_transformer_layers_per_mid_block" + ] + config["temporal_transformer_layers_per_block"] = motion_adapter.config[ + "motion_transformer_layers_per_block" + ] + config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"] + + # For PIA UNets we need to set the number input channels to 9 + if motion_adapter.config["conv_in_channels"]: + config["in_channels"] = motion_adapter.config["conv_in_channels"] + + # Need this for backwards compatibility with UNet2DConditionModel checkpoints + if not config.get("num_attention_heads"): + config["num_attention_heads"] = config["attention_head_dim"] + + expected_kwargs, optional_kwargs = cls._get_signature_keys(cls) + config = FrozenDict({k: config.get(k) for k in config if k in expected_kwargs or k in optional_kwargs}) + config["_class_name"] = cls.__name__ + model = cls.from_config(config) + + if not load_weights: + return model + + # Logic for loading PIA UNets which allow the first 4 channels to be any UNet2DConditionModel conv_in weight + # while the last 5 channels must be PIA conv_in weights. + if has_motion_adapter and motion_adapter.config["conv_in_channels"]: + model.conv_in = motion_adapter.conv_in + updated_conv_in_weight = torch.cat( + [unet.conv_in.weight, motion_adapter.conv_in.weight[:, 4:, :, :]], dim=1 + ) + model.conv_in.load_state_dict({"weight": updated_conv_in_weight, "bias": unet.conv_in.bias}) + else: + model.conv_in.load_state_dict(unet.conv_in.state_dict()) + + model.time_proj.load_state_dict(unet.time_proj.state_dict()) + model.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + + if any( + isinstance(proc, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)) + for proc in unet.attn_processors.values() + ): + attn_procs = {} + for name, processor in unet.attn_processors.items(): + if name.endswith("attn1.processor"): + attn_processor_class = ( + AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor + ) + attn_procs[name] = attn_processor_class() + else: + attn_processor_class = ( + IPAdapterAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else IPAdapterAttnProcessor + ) + attn_procs[name] = attn_processor_class( + hidden_size=processor.hidden_size, + cross_attention_dim=processor.cross_attention_dim, + scale=processor.scale, + num_tokens=processor.num_tokens, + ) + for name, processor in model.attn_processors.items(): + if name not in attn_procs: + attn_procs[name] = processor.__class__() + model.set_attn_processor(attn_procs) + model.config.encoder_hid_dim_type = "ip_image_proj" + model.encoder_hid_proj = unet.encoder_hid_proj + + for i, down_block in enumerate(unet.down_blocks): + model.down_blocks[i].resnets.load_state_dict(down_block.resnets.state_dict()) + if hasattr(model.down_blocks[i], "attentions"): + model.down_blocks[i].attentions.load_state_dict(down_block.attentions.state_dict()) + if model.down_blocks[i].downsamplers: + model.down_blocks[i].downsamplers.load_state_dict(down_block.downsamplers.state_dict()) + + for i, up_block in enumerate(unet.up_blocks): + model.up_blocks[i].resnets.load_state_dict(up_block.resnets.state_dict()) + if hasattr(model.up_blocks[i], "attentions"): + model.up_blocks[i].attentions.load_state_dict(up_block.attentions.state_dict()) + if model.up_blocks[i].upsamplers: + model.up_blocks[i].upsamplers.load_state_dict(up_block.upsamplers.state_dict()) + + model.mid_block.resnets.load_state_dict(unet.mid_block.resnets.state_dict()) + model.mid_block.attentions.load_state_dict(unet.mid_block.attentions.state_dict()) + + if unet.conv_norm_out is not None: + model.conv_norm_out.load_state_dict(unet.conv_norm_out.state_dict()) + if unet.conv_act is not None: + model.conv_act.load_state_dict(unet.conv_act.state_dict()) + model.conv_out.load_state_dict(unet.conv_out.state_dict()) + + if has_motion_adapter: + model.load_motion_modules(motion_adapter) + + # ensure that the Motion UNet is the same dtype as the UNet2DConditionModel + model.to(unet.dtype) + + return model + + def freeze_unet2d_params(self) -> None: + """Freeze the weights of just the UNet2DConditionModel, and leave the motion modules + unfrozen for fine tuning. + """ + # Freeze everything + for param in self.parameters(): + param.requires_grad = False + + # Unfreeze Motion Modules + for down_block in self.down_blocks: + motion_modules = down_block.motion_modules + for param in motion_modules.parameters(): + param.requires_grad = True + + for up_block in self.up_blocks: + motion_modules = up_block.motion_modules + for param in motion_modules.parameters(): + param.requires_grad = True + + if hasattr(self.mid_block, "motion_modules"): + motion_modules = self.mid_block.motion_modules + for param in motion_modules.parameters(): + param.requires_grad = True + + def load_motion_modules(self, motion_adapter: MotionAdapter | None) -> None: + for i, down_block in enumerate(motion_adapter.down_blocks): + self.down_blocks[i].motion_modules.load_state_dict(down_block.motion_modules.state_dict()) + for i, up_block in enumerate(motion_adapter.up_blocks): + self.up_blocks[i].motion_modules.load_state_dict(up_block.motion_modules.state_dict()) + + # to support older motion modules that don't have a mid_block + if hasattr(self.mid_block, "motion_modules"): + self.mid_block.motion_modules.load_state_dict(motion_adapter.mid_block.motion_modules.state_dict()) + + def save_motion_modules( + self, + save_directory: str, + is_main_process: bool = True, + safe_serialization: bool = True, + variant: str | None = None, + push_to_hub: bool = False, + **kwargs, + ) -> None: + state_dict = self.state_dict() + + # Extract all motion modules + motion_state_dict = {} + for k, v in state_dict.items(): + if "motion_modules" in k: + motion_state_dict[k] = v + + adapter = MotionAdapter( + block_out_channels=self.config["block_out_channels"], + motion_layers_per_block=self.config["layers_per_block"], + motion_norm_num_groups=self.config["norm_num_groups"], + motion_num_attention_heads=self.config["motion_num_attention_heads"], + motion_max_seq_length=self.config["motion_max_seq_length"], + use_motion_mid_block=self.config["use_motion_mid_block"], + ) + adapter.load_state_dict(motion_state_dict) + adapter.save_pretrained( + save_directory=save_directory, + is_main_process=is_main_process, + safe_serialization=safe_serialization, + variant=variant, + push_to_hub=push_to_hub, + **kwargs, + ) + + def enable_forward_chunking(self, chunk_size: int | None = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + def disable_forward_chunking(self) -> None: + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, None, 0) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self) -> None: + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None: + r"""Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497. + + The suffixes after the scaling factors represent the stage blocks where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu + def disable_freeu(self) -> None: + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: + setattr(upsample_block, k, None) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + > [!WARNING] > This API is 🧪 experimental. + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedAttnProcessor2_0()) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + > [!WARNING] > This API is 🧪 experimental. + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + @apply_lora_scale("cross_attention_kwargs") + def forward( + self, + sample: torch.Tensor, + timestep: torch.Tensor | float | int, + encoder_hidden_states: torch.Tensor, + timestep_cond: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + added_cond_kwargs: dict[str, torch.Tensor] | None = None, + down_block_additional_residuals: tuple[torch.Tensor] | None = None, + mid_block_additional_residual: torch.Tensor | None = None, + return_dict: bool = True, + ) -> UNetMotionOutput | tuple[torch.Tensor]: + r""" + The [`UNetMotionModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_motion_model.UNetMotionOutput`] instead of a plain + tuple. + + Returns: + [`~models.unets.unet_motion_model.UNetMotionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_motion_model.UNetMotionOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" + if isinstance(timestep, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + num_frames = sample.shape[2] + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb if aug_emb is None else emb + aug_emb + emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds) + image_embeds = [ + image_embed.repeat_interleave(num_frames, dim=0, output_size=image_embed.shape[0] * num_frames) + for image_embed in image_embeds + ] + encoder_hidden_states = (encoder_hidden_states, image_embeds) + + # 2. pre-process + sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) + + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + # To support older versions of motion modules that don't have a mid_block + if hasattr(self.mid_block, "motion_modules"): + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + num_frames=num_frames, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = self.conv_out(sample) + + # reshape to (batch, channel, framerate, width, height) + sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) + + if not return_dict: + return (sample,) + + return UNetMotionOutput(sample=sample) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_spatio_temporal_condition.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_spatio_temporal_condition.py new file mode 100644 index 0000000000000000000000000000000000000000..eddeb9826b0c35c05ed0e118e5db55d8c8716ba0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_spatio_temporal_condition.py @@ -0,0 +1,450 @@ +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import UNet2DConditionLoadersMixin +from ...utils import BaseOutput, logging +from ..attention import AttentionMixin +from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttnProcessor +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNetSpatioTemporalConditionOutput(BaseOutput): + """ + The output of [`UNetSpatioTemporalConditionModel`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_frames, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.Tensor = None + + +class UNetSpatioTemporalConditionModel(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and + returns a sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + down_block_types (`tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`): + The tuple of downsample blocks to use. + up_block_types (`tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`): + The tuple of upsample blocks to use. + block_out_channels (`tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + addition_time_embed_dim: (`int`, defaults to 256): + Dimension to to encode the additional time ids. + projection_class_embeddings_input_dim (`int`, defaults to 768): + The dimension of the projection of encoded `added_time_ids`. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + cross_attention_dim (`int` or `tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int`, `tuple[int]`, or `tuple[tuple]` , *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unets.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], + [`~models.unets.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`], + [`~models.unets.unet_3d_blocks.UNetMidBlockSpatioTemporal`]. + num_attention_heads (`int`, `tuple[int]`, defaults to `(5, 10, 10, 20)`): + The number of attention heads. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: int | None = None, + in_channels: int = 8, + out_channels: int = 4, + down_block_types: tuple[str, ...] = ( + "CrossAttnDownBlockSpatioTemporal", + "CrossAttnDownBlockSpatioTemporal", + "CrossAttnDownBlockSpatioTemporal", + "DownBlockSpatioTemporal", + ), + up_block_types: tuple[str, ...] = ( + "UpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + ), + block_out_channels: tuple[int, ...] = (320, 640, 1280, 1280), + addition_time_embed_dim: int = 256, + projection_class_embeddings_input_dim: int = 768, + layers_per_block: int | tuple[int] = 2, + cross_attention_dim: int | tuple[int] = 1024, + transformer_layers_per_block: int | tuple[int, tuple[tuple]] = 1, + num_attention_heads: int | tuple[int, ...] = (5, 10, 20, 20), + num_frames: int = 25, + ): + super().__init__() + + self.sample_size = sample_size + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + + # input + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=3, + padding=1, + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=1e-5, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + resnet_act_fn="silu", + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlockSpatioTemporal( + block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block[-1], + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=1e-5, + resolution_idx=i, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + resnet_act_fn="silu", + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5) + self.conv_act = nn.SiLU() + + self.conv_out = nn.Conv2d( + block_out_channels[0], + out_channels, + kernel_size=3, + padding=1, + ) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: int | None = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + def forward( + self, + sample: torch.Tensor, + timestep: torch.Tensor | float | int, + encoder_hidden_states: torch.Tensor, + added_time_ids: torch.Tensor, + return_dict: bool = True, + ) -> UNetSpatioTemporalConditionOutput | tuple: + r""" + The [`UNetSpatioTemporalConditionModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`. + added_time_ids: (`torch.Tensor`): + The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal + embeddings and added to the time embeddings. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead + of a plain tuple. + Returns: + [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is + returned, otherwise a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" + if isinstance(timestep, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + batch_size, num_frames = sample.shape[:2] + timesteps = timesteps.expand(batch_size) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb) + + time_embeds = self.add_time_proj(added_time_ids.flatten()) + time_embeds = time_embeds.reshape((batch_size, -1)) + time_embeds = time_embeds.to(emb.dtype) + aug_emb = self.add_embedding(time_embeds) + emb = emb + aug_emb + + # Flatten the batch and frames dimensions + # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width] + sample = sample.flatten(0, 1) + # Repeat the embeddings num_video_frames times + # emb: [batch, channels] -> [batch * frames, channels] + emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames) + # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels] + encoder_hidden_states = encoder_hidden_states.repeat_interleave( + num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames + ) + + # 2. pre-process + sample = self.conv_in(sample) + + image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device) + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + ) + else: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + image_only_indicator=image_only_indicator, + ) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + ) + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + image_only_indicator=image_only_indicator, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + image_only_indicator=image_only_indicator, + ) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + # 7. Reshape back to original shape + sample = sample.reshape(batch_size, num_frames, *sample.shape[1:]) + + if not return_dict: + return (sample,) + + return UNetSpatioTemporalConditionOutput(sample=sample) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_stable_cascade.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_stable_cascade.py new file mode 100644 index 0000000000000000000000000000000000000000..5a6f24ab794bb899ac496c788463bcc744c91943 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/unet_stable_cascade.py @@ -0,0 +1,583 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass + +import numpy as np +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin +from ...utils import BaseOutput +from ..attention_processor import Attention +from ..modeling_utils import ModelMixin + + +# Copied from diffusers.pipelines.wuerstchen.modeling_wuerstchen_common.WuerstchenLayerNorm with WuerstchenLayerNorm -> SDCascadeLayerNorm +class SDCascadeLayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + x = x.permute(0, 2, 3, 1) + x = super().forward(x) + return x.permute(0, 3, 1, 2) + + +class SDCascadeTimestepBlock(nn.Module): + def __init__(self, c, c_timestep, conds=[]): + super().__init__() + + self.mapper = nn.Linear(c_timestep, c * 2) + self.conds = conds + for cname in conds: + setattr(self, f"mapper_{cname}", nn.Linear(c_timestep, c * 2)) + + def forward(self, x, t): + t = t.chunk(len(self.conds) + 1, dim=1) + a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1) + for i, c in enumerate(self.conds): + ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1) + a, b = a + ac, b + bc + return x * (1 + a) + b + + +class SDCascadeResBlock(nn.Module): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): + super().__init__() + self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) + self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c + c_skip, c * 4), + nn.GELU(), + GlobalResponseNorm(c * 4), + nn.Dropout(dropout), + nn.Linear(c * 4, c), + ) + + def forward(self, x, x_skip=None): + x_res = x + x = self.norm(self.depthwise(x)) + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + x_res + + +# from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105 +class GlobalResponseNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + agg_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + stand_div_norm = agg_norm / (agg_norm.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * stand_div_norm) + self.beta + x + + +class SDCascadeAttnBlock(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): + super().__init__() + + self.self_attn = self_attn + self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True) + self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c)) + + def forward(self, x, kv): + kv = self.kv_mapper(kv) + norm_x = self.norm(x) + if self.self_attn: + batch_size, channel, _, _ = x.shape + kv = torch.cat([norm_x.view(batch_size, channel, -1).transpose(1, 2), kv], dim=1) + x = x + self.attention(norm_x, encoder_hidden_states=kv) + return x + + +class UpDownBlock2d(nn.Module): + def __init__(self, in_channels, out_channels, mode, enabled=True): + super().__init__() + if mode not in ["up", "down"]: + raise ValueError(f"{mode} not supported") + interpolation = ( + nn.Upsample(scale_factor=2 if mode == "up" else 0.5, mode="bilinear", align_corners=True) + if enabled + else nn.Identity() + ) + mapping = nn.Conv2d(in_channels, out_channels, kernel_size=1) + self.blocks = nn.ModuleList([interpolation, mapping] if mode == "up" else [mapping, interpolation]) + + def forward(self, x): + for block in self.blocks: + x = block(x) + return x + + +@dataclass +class StableCascadeUNetOutput(BaseOutput): + sample: torch.Tensor = None + + +class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 16, + out_channels: int = 16, + timestep_ratio_embedding_dim: int = 64, + patch_size: int = 1, + conditioning_dim: int = 2048, + block_out_channels: tuple[int, ...] = (2048, 2048), + num_attention_heads: tuple[int, ...] = (32, 32), + down_num_layers_per_block: tuple[int, ...] = (8, 24), + up_num_layers_per_block: tuple[int, ...] = (24, 8), + down_blocks_repeat_mappers: tuple[int] | None = ( + 1, + 1, + ), + up_blocks_repeat_mappers: tuple[int] | None = (1, 1), + block_types_per_layer: tuple[tuple[str]] = ( + ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), + ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), + ), + clip_text_in_channels: int | None = None, + clip_text_pooled_in_channels=1280, + clip_image_in_channels: int | None = None, + clip_seq=4, + effnet_in_channels: int | None = None, + pixel_mapper_in_channels: int | None = None, + kernel_size=3, + dropout: float | tuple[float] = (0.1, 0.1), + self_attn: bool | tuple[bool] = True, + timestep_conditioning_type: tuple[str, ...] = ("sca", "crp"), + switch_level: tuple[bool] | None = None, + ): + """ + + Parameters: + in_channels (`int`, defaults to 16): + Number of channels in the input sample. + out_channels (`int`, defaults to 16): + Number of channels in the output sample. + timestep_ratio_embedding_dim (`int`, defaults to 64): + Dimension of the projected time embedding. + patch_size (`int`, defaults to 1): + Patch size to use for pixel unshuffling layer + conditioning_dim (`int`, defaults to 2048): + Dimension of the image and text conditional embedding. + block_out_channels (tuple[int], defaults to (2048, 2048)): + tuple of output channels for each block. + num_attention_heads (tuple[int], defaults to (32, 32)): + Number of attention heads in each attention block. Set to -1 to if block types in a layer do not have + attention. + down_num_layers_per_block (tuple[int], defaults to [8, 24]): + Number of layers in each down block. + up_num_layers_per_block (tuple[int], defaults to [24, 8]): + Number of layers in each up block. + down_blocks_repeat_mappers (tuple[int], optional, defaults to [1, 1]): + Number of 1x1 Convolutional layers to repeat in each down block. + up_blocks_repeat_mappers (tuple[int], optional, defaults to [1, 1]): + Number of 1x1 Convolutional layers to repeat in each up block. + block_types_per_layer (tuple[tuple[str]], optional, + defaults to ( + ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), ("SDCascadeResBlock", + "SDCascadeTimestepBlock", "SDCascadeAttnBlock") + ): Block types used in each layer of the up/down blocks. + clip_text_in_channels (`int`, *optional*, defaults to `None`): + Number of input channels for CLIP based text conditioning. + clip_text_pooled_in_channels (`int`, *optional*, defaults to 1280): + Number of input channels for pooled CLIP text embeddings. + clip_image_in_channels (`int`, *optional*): + Number of input channels for CLIP based image conditioning. + clip_seq (`int`, *optional*, defaults to 4): + effnet_in_channels (`int`, *optional*, defaults to `None`): + Number of input channels for effnet conditioning. + pixel_mapper_in_channels (`int`, defaults to `None`): + Number of input channels for pixel mapper conditioning. + kernel_size (`int`, *optional*, defaults to 3): + Kernel size to use in the block convolutional layers. + dropout (tuple[float], *optional*, defaults to (0.1, 0.1)): + Dropout to use per block. + self_attn (bool | tuple[bool]): + tuple of booleans that determine whether to use self attention in a block or not. + timestep_conditioning_type (tuple[str], defaults to ("sca", "crp")): + Timestep conditioning type. + switch_level (tuple[bool] | None, *optional*, defaults to `None`): + tuple that indicates whether upsampling or downsampling should be applied in a block + """ + + super().__init__() + + if len(block_out_channels) != len(down_num_layers_per_block): + raise ValueError( + f"Number of elements in `down_num_layers_per_block` must match the length of `block_out_channels`: {len(block_out_channels)}" + ) + + elif len(block_out_channels) != len(up_num_layers_per_block): + raise ValueError( + f"Number of elements in `up_num_layers_per_block` must match the length of `block_out_channels`: {len(block_out_channels)}" + ) + + elif len(block_out_channels) != len(down_blocks_repeat_mappers): + raise ValueError( + f"Number of elements in `down_blocks_repeat_mappers` must match the length of `block_out_channels`: {len(block_out_channels)}" + ) + + elif len(block_out_channels) != len(up_blocks_repeat_mappers): + raise ValueError( + f"Number of elements in `up_blocks_repeat_mappers` must match the length of `block_out_channels`: {len(block_out_channels)}" + ) + + elif len(block_out_channels) != len(block_types_per_layer): + raise ValueError( + f"Number of elements in `block_types_per_layer` must match the length of `block_out_channels`: {len(block_out_channels)}" + ) + + if isinstance(dropout, float): + dropout = (dropout,) * len(block_out_channels) + if isinstance(self_attn, bool): + self_attn = (self_attn,) * len(block_out_channels) + + # CONDITIONING + if effnet_in_channels is not None: + self.effnet_mapper = nn.Sequential( + nn.Conv2d(effnet_in_channels, block_out_channels[0] * 4, kernel_size=1), + nn.GELU(), + nn.Conv2d(block_out_channels[0] * 4, block_out_channels[0], kernel_size=1), + SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6), + ) + if pixel_mapper_in_channels is not None: + self.pixels_mapper = nn.Sequential( + nn.Conv2d(pixel_mapper_in_channels, block_out_channels[0] * 4, kernel_size=1), + nn.GELU(), + nn.Conv2d(block_out_channels[0] * 4, block_out_channels[0], kernel_size=1), + SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6), + ) + + self.clip_txt_pooled_mapper = nn.Linear(clip_text_pooled_in_channels, conditioning_dim * clip_seq) + if clip_text_in_channels is not None: + self.clip_txt_mapper = nn.Linear(clip_text_in_channels, conditioning_dim) + if clip_image_in_channels is not None: + self.clip_img_mapper = nn.Linear(clip_image_in_channels, conditioning_dim * clip_seq) + self.clip_norm = nn.LayerNorm(conditioning_dim, elementwise_affine=False, eps=1e-6) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + nn.Conv2d(in_channels * (patch_size**2), block_out_channels[0], kernel_size=1), + SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6), + ) + + def get_block(block_type, in_channels, nhead, c_skip=0, dropout=0, self_attn=True): + if block_type == "SDCascadeResBlock": + return SDCascadeResBlock(in_channels, c_skip, kernel_size=kernel_size, dropout=dropout) + elif block_type == "SDCascadeAttnBlock": + return SDCascadeAttnBlock(in_channels, conditioning_dim, nhead, self_attn=self_attn, dropout=dropout) + elif block_type == "SDCascadeTimestepBlock": + return SDCascadeTimestepBlock( + in_channels, timestep_ratio_embedding_dim, conds=timestep_conditioning_type + ) + else: + raise ValueError(f"Block type {block_type} not supported") + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + self.down_downscalers = nn.ModuleList() + self.down_repeat_mappers = nn.ModuleList() + for i in range(len(block_out_channels)): + if i > 0: + self.down_downscalers.append( + nn.Sequential( + SDCascadeLayerNorm(block_out_channels[i - 1], elementwise_affine=False, eps=1e-6), + UpDownBlock2d( + block_out_channels[i - 1], block_out_channels[i], mode="down", enabled=switch_level[i - 1] + ) + if switch_level is not None + else nn.Conv2d(block_out_channels[i - 1], block_out_channels[i], kernel_size=2, stride=2), + ) + ) + else: + self.down_downscalers.append(nn.Identity()) + + down_block = nn.ModuleList() + for _ in range(down_num_layers_per_block[i]): + for block_type in block_types_per_layer[i]: + block = get_block( + block_type, + block_out_channels[i], + num_attention_heads[i], + dropout=dropout[i], + self_attn=self_attn[i], + ) + down_block.append(block) + self.down_blocks.append(down_block) + + if down_blocks_repeat_mappers is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(down_blocks_repeat_mappers[i] - 1): + block_repeat_mappers.append(nn.Conv2d(block_out_channels[i], block_out_channels[i], kernel_size=1)) + self.down_repeat_mappers.append(block_repeat_mappers) + + # -- up blocks + self.up_blocks = nn.ModuleList() + self.up_upscalers = nn.ModuleList() + self.up_repeat_mappers = nn.ModuleList() + for i in reversed(range(len(block_out_channels))): + if i > 0: + self.up_upscalers.append( + nn.Sequential( + SDCascadeLayerNorm(block_out_channels[i], elementwise_affine=False, eps=1e-6), + UpDownBlock2d( + block_out_channels[i], block_out_channels[i - 1], mode="up", enabled=switch_level[i - 1] + ) + if switch_level is not None + else nn.ConvTranspose2d( + block_out_channels[i], block_out_channels[i - 1], kernel_size=2, stride=2 + ), + ) + ) + else: + self.up_upscalers.append(nn.Identity()) + + up_block = nn.ModuleList() + for j in range(up_num_layers_per_block[::-1][i]): + for k, block_type in enumerate(block_types_per_layer[i]): + c_skip = block_out_channels[i] if i < len(block_out_channels) - 1 and j == k == 0 else 0 + block = get_block( + block_type, + block_out_channels[i], + num_attention_heads[i], + c_skip=c_skip, + dropout=dropout[i], + self_attn=self_attn[i], + ) + up_block.append(block) + self.up_blocks.append(up_block) + + if up_blocks_repeat_mappers is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(up_blocks_repeat_mappers[::-1][i] - 1): + block_repeat_mappers.append(nn.Conv2d(block_out_channels[i], block_out_channels[i], kernel_size=1)) + self.up_repeat_mappers.append(block_repeat_mappers) + + # OUTPUT + self.clf = nn.Sequential( + SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6), + nn.Conv2d(block_out_channels[0], out_channels * (patch_size**2), kernel_size=1), + nn.PixelShuffle(patch_size), + ) + + self.gradient_checkpointing = False + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) + nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) if hasattr(self, "clip_txt_mapper") else None + nn.init.normal_(self.clip_img_mapper.weight, std=0.02) if hasattr(self, "clip_img_mapper") else None + + if hasattr(self, "effnet_mapper"): + nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings + nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings + + if hasattr(self, "pixels_mapper"): + nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings + nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings + + torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + nn.init.constant_(self.clf[1].weight, 0) # outputs + + # blocks + for level_block in self.down_blocks + self.up_blocks: + for block in level_block: + if isinstance(block, SDCascadeResBlock): + block.channelwise[-1].weight.data *= np.sqrt(1 / sum(self.config.blocks[0])) + elif isinstance(block, SDCascadeTimestepBlock): + nn.init.constant_(block.mapper.weight, 0) + + def get_timestep_ratio_embedding(self, timestep_ratio, max_positions=10000): + r = timestep_ratio * max_positions + half_dim = self.config.timestep_ratio_embedding_dim // 2 + + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + + if self.config.timestep_ratio_embedding_dim % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode="constant") + + return emb.to(dtype=r.dtype) + + def get_clip_embeddings(self, clip_txt_pooled, clip_txt=None, clip_img=None): + if len(clip_txt_pooled.shape) == 2: + clip_txt_pool = clip_txt_pooled.unsqueeze(1) + clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view( + clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.config.clip_seq, -1 + ) + if clip_txt is not None and clip_img is not None: + clip_txt = self.clip_txt_mapper(clip_txt) + if len(clip_img.shape) == 2: + clip_img = clip_img.unsqueeze(1) + clip_img = self.clip_img_mapper(clip_img).view( + clip_img.size(0), clip_img.size(1) * self.config.clip_seq, -1 + ) + clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1) + else: + clip = clip_txt_pool + return self.clip_norm(clip) + + def _down_encode(self, x, r_embed, clip): + level_outputs = [] + block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, SDCascadeResBlock): + x = self._gradient_checkpointing_func(block, x) + elif isinstance(block, SDCascadeAttnBlock): + x = self._gradient_checkpointing_func(block, x, clip) + elif isinstance(block, SDCascadeTimestepBlock): + x = self._gradient_checkpointing_func(block, x, r_embed) + else: + x = self._gradient_checkpointing_func(block) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + else: + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, SDCascadeResBlock): + x = block(x) + elif isinstance(block, SDCascadeAttnBlock): + x = block(x, clip) + elif isinstance(block, SDCascadeTimestepBlock): + x = block(x, r_embed) + else: + x = block(x) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, clip): + x = level_outputs[0] + block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, SDCascadeResBlock): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + orig_type = x.dtype + x = torch.nn.functional.interpolate( + x.float(), skip.shape[-2:], mode="bilinear", align_corners=True + ) + x = x.to(orig_type) + x = self._gradient_checkpointing_func(block, x, skip) + elif isinstance(block, SDCascadeAttnBlock): + x = self._gradient_checkpointing_func(block, x, clip) + elif isinstance(block, SDCascadeTimestepBlock): + x = self._gradient_checkpointing_func(block, x, r_embed) + else: + x = self._gradient_checkpointing_func(block, x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + else: + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, SDCascadeResBlock): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + orig_type = x.dtype + x = torch.nn.functional.interpolate( + x.float(), skip.shape[-2:], mode="bilinear", align_corners=True + ) + x = x.to(orig_type) + x = block(x, skip) + elif isinstance(block, SDCascadeAttnBlock): + x = block(x, clip) + elif isinstance(block, SDCascadeTimestepBlock): + x = block(x, r_embed) + else: + x = block(x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + return x + + def forward( + self, + sample, + timestep_ratio, + clip_text_pooled, + clip_text=None, + clip_img=None, + effnet=None, + pixels=None, + sca=None, + crp=None, + return_dict=True, + ): + if pixels is None: + pixels = sample.new_zeros(sample.size(0), 3, 8, 8) + + # Process the conditioning embeddings + timestep_ratio_embed = self.get_timestep_ratio_embedding(timestep_ratio) + for c in self.config.timestep_conditioning_type: + if c == "sca": + cond = sca + elif c == "crp": + cond = crp + else: + cond = None + t_cond = cond or torch.zeros_like(timestep_ratio) + timestep_ratio_embed = torch.cat([timestep_ratio_embed, self.get_timestep_ratio_embedding(t_cond)], dim=1) + clip = self.get_clip_embeddings(clip_txt_pooled=clip_text_pooled, clip_txt=clip_text, clip_img=clip_img) + + # Model Blocks + x = self.embedding(sample) + if hasattr(self, "effnet_mapper") and effnet is not None: + x = x + self.effnet_mapper( + nn.functional.interpolate(effnet, size=x.shape[-2:], mode="bilinear", align_corners=True) + ) + if hasattr(self, "pixels_mapper"): + x = x + nn.functional.interpolate( + self.pixels_mapper(pixels), size=x.shape[-2:], mode="bilinear", align_corners=True + ) + level_outputs = self._down_encode(x, timestep_ratio_embed, clip) + x = self._up_decode(level_outputs, timestep_ratio_embed, clip) + sample = self.clf(x) + + if not return_dict: + return (sample,) + return StableCascadeUNetOutput(sample=sample) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/uvit_2d.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/uvit_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..836d41a7f9464d185beaa002dea8ea2a7defa081 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/models/unets/uvit_2d.py @@ -0,0 +1,407 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.checkpoint import checkpoint + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...utils import apply_lora_scale +from ..attention import AttentionMixin, BasicTransformerBlock, SkipFFTransformerBlock +from ..attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttnAddedKVProcessor, + AttnProcessor, +) +from ..embeddings import TimestepEmbedding, get_timestep_embedding +from ..modeling_utils import ModelMixin +from ..normalization import GlobalResponseNorm, RMSNorm +from ..resnet import Downsample2D, Upsample2D + + +class UVit2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + # global config + hidden_size: int = 1024, + use_bias: bool = False, + hidden_dropout: float = 0.0, + # conditioning dimensions + cond_embed_dim: int = 768, + micro_cond_encode_dim: int = 256, + micro_cond_embed_dim: int = 1280, + encoder_hidden_size: int = 768, + # num tokens + vocab_size: int = 8256, # codebook_size + 1 (for the mask token) rounded + codebook_size: int = 8192, + # `UVit2DConvEmbed` + in_channels: int = 768, + block_out_channels: int = 768, + num_res_blocks: int = 3, + downsample: bool = False, + upsample: bool = False, + block_num_heads: int = 12, + # `TransformerLayer` + num_hidden_layers: int = 22, + num_attention_heads: int = 16, + # `Attention` + attention_dropout: float = 0.0, + # `FeedForward` + intermediate_size: int = 2816, + # `Norm` + layer_norm_eps: float = 1e-6, + ln_elementwise_affine: bool = True, + sample_size: int = 64, + ): + super().__init__() + + self.encoder_proj = nn.Linear(encoder_hidden_size, hidden_size, bias=use_bias) + self.encoder_proj_layer_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine) + + self.embed = UVit2DConvEmbed( + in_channels, block_out_channels, vocab_size, ln_elementwise_affine, layer_norm_eps, use_bias + ) + + self.cond_embed = TimestepEmbedding( + micro_cond_embed_dim + cond_embed_dim, hidden_size, sample_proj_bias=use_bias + ) + + self.down_block = UVitBlock( + block_out_channels, + num_res_blocks, + hidden_size, + hidden_dropout, + ln_elementwise_affine, + layer_norm_eps, + use_bias, + block_num_heads, + attention_dropout, + downsample, + False, + ) + + self.project_to_hidden_norm = RMSNorm(block_out_channels, layer_norm_eps, ln_elementwise_affine) + self.project_to_hidden = nn.Linear(block_out_channels, hidden_size, bias=use_bias) + + self.transformer_layers = nn.ModuleList( + [ + BasicTransformerBlock( + dim=hidden_size, + num_attention_heads=num_attention_heads, + attention_head_dim=hidden_size // num_attention_heads, + dropout=hidden_dropout, + cross_attention_dim=hidden_size, + attention_bias=use_bias, + norm_type="ada_norm_continuous", + ada_norm_continous_conditioning_embedding_dim=hidden_size, + norm_elementwise_affine=ln_elementwise_affine, + norm_eps=layer_norm_eps, + ada_norm_bias=use_bias, + ff_inner_dim=intermediate_size, + ff_bias=use_bias, + attention_out_bias=use_bias, + ) + for _ in range(num_hidden_layers) + ] + ) + + self.project_from_hidden_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine) + self.project_from_hidden = nn.Linear(hidden_size, block_out_channels, bias=use_bias) + + self.up_block = UVitBlock( + block_out_channels, + num_res_blocks, + hidden_size, + hidden_dropout, + ln_elementwise_affine, + layer_norm_eps, + use_bias, + block_num_heads, + attention_dropout, + downsample=False, + upsample=upsample, + ) + + self.mlm_layer = ConvMlmLayer( + block_out_channels, in_channels, use_bias, ln_elementwise_affine, layer_norm_eps, codebook_size + ) + + self.gradient_checkpointing = False + + @apply_lora_scale("cross_attention_kwargs") + def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None): + encoder_hidden_states = self.encoder_proj(encoder_hidden_states) + encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states) + + micro_cond_embeds = get_timestep_embedding( + micro_conds.flatten(), self.config.micro_cond_encode_dim, flip_sin_to_cos=True, downscale_freq_shift=0 + ) + + micro_cond_embeds = micro_cond_embeds.reshape((input_ids.shape[0], -1)) + + pooled_text_emb = torch.cat([pooled_text_emb, micro_cond_embeds], dim=1) + pooled_text_emb = pooled_text_emb.to(dtype=self.dtype) + pooled_text_emb = self.cond_embed(pooled_text_emb).to(encoder_hidden_states.dtype) + + hidden_states = self.embed(input_ids) + + hidden_states = self.down_block( + hidden_states, + pooled_text_emb=pooled_text_emb, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels) + + hidden_states = self.project_to_hidden_norm(hidden_states) + hidden_states = self.project_to_hidden(hidden_states) + + for layer in self.transformer_layers: + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def layer_(*args): + return checkpoint(layer, *args) + + else: + layer_ = layer + + hidden_states = layer_( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs={"pooled_text_emb": pooled_text_emb}, + ) + + hidden_states = self.project_from_hidden_norm(hidden_states) + hidden_states = self.project_from_hidden(hidden_states) + + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + + hidden_states = self.up_block( + hidden_states, + pooled_text_emb=pooled_text_emb, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + logits = self.mlm_layer(hidden_states) + + return logits + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + +class UVit2DConvEmbed(nn.Module): + def __init__(self, in_channels, block_out_channels, vocab_size, elementwise_affine, eps, bias): + super().__init__() + self.embeddings = nn.Embedding(vocab_size, in_channels) + self.layer_norm = RMSNorm(in_channels, eps, elementwise_affine) + self.conv = nn.Conv2d(in_channels, block_out_channels, kernel_size=1, bias=bias) + + def forward(self, input_ids): + embeddings = self.embeddings(input_ids) + embeddings = self.layer_norm(embeddings) + embeddings = embeddings.permute(0, 3, 1, 2) + embeddings = self.conv(embeddings) + return embeddings + + +class UVitBlock(nn.Module): + def __init__( + self, + channels, + num_res_blocks: int, + hidden_size, + hidden_dropout, + ln_elementwise_affine, + layer_norm_eps, + use_bias, + block_num_heads, + attention_dropout, + downsample: bool, + upsample: bool, + ): + super().__init__() + + if downsample: + self.downsample = Downsample2D( + channels, + use_conv=True, + padding=0, + name="Conv2d_0", + kernel_size=2, + norm_type="rms_norm", + eps=layer_norm_eps, + elementwise_affine=ln_elementwise_affine, + bias=use_bias, + ) + else: + self.downsample = None + + self.res_blocks = nn.ModuleList( + [ + ConvNextBlock( + channels, + layer_norm_eps, + ln_elementwise_affine, + use_bias, + hidden_dropout, + hidden_size, + ) + for i in range(num_res_blocks) + ] + ) + + self.attention_blocks = nn.ModuleList( + [ + SkipFFTransformerBlock( + channels, + block_num_heads, + channels // block_num_heads, + hidden_size, + use_bias, + attention_dropout, + channels, + attention_bias=use_bias, + attention_out_bias=use_bias, + ) + for _ in range(num_res_blocks) + ] + ) + + if upsample: + self.upsample = Upsample2D( + channels, + use_conv_transpose=True, + kernel_size=2, + padding=0, + name="conv", + norm_type="rms_norm", + eps=layer_norm_eps, + elementwise_affine=ln_elementwise_affine, + bias=use_bias, + interpolate=False, + ) + else: + self.upsample = None + + def forward(self, x, pooled_text_emb, encoder_hidden_states, cross_attention_kwargs): + if self.downsample is not None: + x = self.downsample(x) + + for res_block, attention_block in zip(self.res_blocks, self.attention_blocks): + x = res_block(x, pooled_text_emb) + + batch_size, channels, height, width = x.shape + x = x.view(batch_size, channels, height * width).permute(0, 2, 1) + x = attention_block( + x, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs + ) + x = x.permute(0, 2, 1).view(batch_size, channels, height, width) + + if self.upsample is not None: + x = self.upsample(x) + + return x + + +class ConvNextBlock(nn.Module): + def __init__( + self, channels, layer_norm_eps, ln_elementwise_affine, use_bias, hidden_dropout, hidden_size, res_ffn_factor=4 + ): + super().__init__() + self.depthwise = nn.Conv2d( + channels, + channels, + kernel_size=3, + padding=1, + groups=channels, + bias=use_bias, + ) + self.norm = RMSNorm(channels, layer_norm_eps, ln_elementwise_affine) + self.channelwise_linear_1 = nn.Linear(channels, int(channels * res_ffn_factor), bias=use_bias) + self.channelwise_act = nn.GELU() + self.channelwise_norm = GlobalResponseNorm(int(channels * res_ffn_factor)) + self.channelwise_linear_2 = nn.Linear(int(channels * res_ffn_factor), channels, bias=use_bias) + self.channelwise_dropout = nn.Dropout(hidden_dropout) + self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, use_bias) + + def forward(self, x, cond_embeds): + x_res = x + + x = self.depthwise(x) + + x = x.permute(0, 2, 3, 1) + x = self.norm(x) + + x = self.channelwise_linear_1(x) + x = self.channelwise_act(x) + x = self.channelwise_norm(x) + x = self.channelwise_linear_2(x) + x = self.channelwise_dropout(x) + + x = x.permute(0, 3, 1, 2) + + x = x + x_res + + scale, shift = self.cond_embeds_mapper(F.silu(cond_embeds)).chunk(2, dim=1) + x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None] + + return x + + +class ConvMlmLayer(nn.Module): + def __init__( + self, + block_out_channels: int, + in_channels: int, + use_bias: bool, + ln_elementwise_affine: bool, + layer_norm_eps: float, + codebook_size: int, + ): + super().__init__() + self.conv1 = nn.Conv2d(block_out_channels, in_channels, kernel_size=1, bias=use_bias) + self.layer_norm = RMSNorm(in_channels, layer_norm_eps, ln_elementwise_affine) + self.conv2 = nn.Conv2d(in_channels, codebook_size, kernel_size=1, bias=use_bias) + + def forward(self, hidden_states): + hidden_states = self.conv1(hidden_states) + hidden_states = self.layer_norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + logits = self.conv2(hidden_states) + return logits diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e99a08d5920c6154e80489bd3bbe8e4bb7b03500 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/__pycache__/components_manager.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/__pycache__/components_manager.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abb327f20616e84434464dd7fc5e83f0cf775d2a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/__pycache__/components_manager.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/__pycache__/mellon_node_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/__pycache__/mellon_node_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a5d0dde98ed747c678ef440a2b4c10e46ba8bc7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/__pycache__/mellon_node_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/__pycache__/modular_pipeline_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/__pycache__/modular_pipeline_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e3bea4f759ab72e2246aa6db7d6acc7dce99250 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/__pycache__/modular_pipeline_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4754ed01ce6aee9c85259fef1774ea96fbd27009 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__init__.py @@ -0,0 +1,49 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_blocks_flux"] = ["FluxAutoBlocks"] + _import_structure["modular_blocks_flux_kontext"] = ["FluxKontextAutoBlocks"] + _import_structure["modular_pipeline"] = ["FluxKontextModularPipeline", "FluxModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_blocks_flux import FluxAutoBlocks + from .modular_blocks_flux_kontext import FluxKontextAutoBlocks + from .modular_pipeline import FluxKontextModularPipeline, FluxModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f23c2bb7a1e61d9f9ad085483cd1eb34ce6dc67e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/before_denoise.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/before_denoise.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0a664475ffbaed2ffade3350c34784d981227cc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/before_denoise.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/decoders.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/decoders.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05557d13f07603bd03a931aafff0546699bbecb6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/decoders.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/denoise.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/denoise.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb12e6717f3a3dea2441ec89414a9717acdf2c3a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/denoise.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/encoders.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/encoders.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85f63faf7f90fc6b9829feb4e17d9c9e8f7e0588 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/encoders.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/inputs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/inputs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ebaf50562a87315b2a81c3266197f3f2362f83b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/inputs.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/modular_blocks_flux.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/modular_blocks_flux.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09715485abc994dca6550260ecf05adff8934165 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/modular_blocks_flux.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/modular_blocks_flux_kontext.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/modular_blocks_flux_kontext.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a45cc6d2ed3c634ec23fe8e80020bfe296b78d22 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/modular_blocks_flux_kontext.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/modular_pipeline.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/modular_pipeline.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5d6aaf8ab82e0f31dcd00ce72a876ce404e22bc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/__pycache__/modular_pipeline.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/before_denoise.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/before_denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..c28154775f5ad4d7c7b0bb6dd9471f934a9100cd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/before_denoise.py @@ -0,0 +1,618 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import numpy as np +import torch + +from ...pipelines import FluxPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import FluxModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def _get_initial_timesteps_and_optionals( + transformer, + scheduler, + batch_size, + height, + width, + vae_scale_factor, + num_inference_steps, + guidance_scale, + sigmas, + device, +): + image_seq_len = (int(height) // vae_scale_factor // 2) * (int(width) // vae_scale_factor // 2) + + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas: + sigmas = None + mu = calculate_shift( + image_seq_len, + scheduler.config.get("base_image_seq_len", 256), + scheduler.config.get("max_image_seq_len", 4096), + scheduler.config.get("base_shift", 0.5), + scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu) + if transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(batch_size) + else: + guidance = None + + return timesteps, num_inference_steps, sigmas, guidance + + +class FluxSetTimestepsStep(ModularPipelineBlocks): + model_name = "flux" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("guidance_scale", default=3.5), + InputParam("latents", type_hint=torch.Tensor), + InputParam("num_images_per_prompt", default=1), + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam( + "num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time", + ), + OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."), + ] + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + scheduler = components.scheduler + transformer = components.transformer + + batch_size = block_state.batch_size * block_state.num_images_per_prompt + timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals( + transformer, + scheduler, + batch_size, + block_state.height, + block_state.width, + components.vae_scale_factor, + block_state.num_inference_steps, + block_state.guidance_scale, + block_state.sigmas, + block_state.device, + ) + block_state.timesteps = timesteps + block_state.num_inference_steps = num_inference_steps + block_state.sigmas = sigmas + block_state.guidance = guidance + + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + components.scheduler.set_begin_index(0) + + self.set_block_state(state, block_state) + return components, state + + +class FluxImg2ImgSetTimestepsStep(ModularPipelineBlocks): + model_name = "flux" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("strength", default=0.6), + InputParam("guidance_scale", default=3.5), + InputParam("num_images_per_prompt", default=1), + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam( + "num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time", + ), + OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."), + ] + + @staticmethod + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps with self.scheduler->scheduler + def get_timesteps(scheduler, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = scheduler.timesteps[t_start * scheduler.order :] + if hasattr(scheduler, "set_begin_index"): + scheduler.set_begin_index(t_start * scheduler.order) + + return timesteps, num_inference_steps - t_start + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + + scheduler = components.scheduler + transformer = components.transformer + batch_size = block_state.batch_size * block_state.num_images_per_prompt + timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals( + transformer, + scheduler, + batch_size, + block_state.height, + block_state.width, + components.vae_scale_factor, + block_state.num_inference_steps, + block_state.guidance_scale, + block_state.sigmas, + block_state.device, + ) + timesteps, num_inference_steps = self.get_timesteps( + scheduler, num_inference_steps, block_state.strength, block_state.device + ) + block_state.timesteps = timesteps + block_state.num_inference_steps = num_inference_steps + block_state.sigmas = sigmas + block_state.guidance = guidance + + self.set_block_state(state, block_state) + return components, state + + +class FluxPrepareLatentsStep(ModularPipelineBlocks): + model_name = "flux" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [] + + @property + def description(self) -> str: + return "Prepare latents step that prepares the latents for the text-to-image generation process" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("latents", type_hint=torch.Tensor | None), + InputParam("num_images_per_prompt", type_hint=int, default=1), + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.", + ), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ), + ] + + @staticmethod + def check_inputs(components, block_state): + if (block_state.height is not None and block_state.height % (components.vae_scale_factor * 2) != 0) or ( + block_state.width is not None and block_state.width % (components.vae_scale_factor * 2) != 0 + ): + logger.warning( + f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}." + ) + + @staticmethod + def prepare_latents( + comp, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (comp.vae_scale_factor * 2)) + width = 2 * (int(width) // (comp.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + return latents.to(device=device, dtype=dtype) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # TODO: move packing latents code to a patchifier similar to Qwen + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = FluxPipeline._pack_latents(latents, batch_size, num_channels_latents, height, width) + + return latents + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + block_state.device = components._execution_device + block_state.num_channels_latents = components.num_channels_latents + + self.check_inputs(components, block_state) + batch_size = block_state.batch_size * block_state.num_images_per_prompt + block_state.latents = self.prepare_latents( + components, + batch_size, + block_state.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + ) + + self.set_block_state(state, block_state) + + return components, state + + +class FluxImg2ImgPrepareLatentsStep(ModularPipelineBlocks): + model_name = "flux" + + @property + def description(self) -> str: + return "Step that adds noise to image latents for image-to-image. Should be run after `set_timesteps`," + " `prepare_latents`. Both noise and image latents should already be patchified." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The initial random noised, can be generated in prepare latent step.", + ), + InputParam( + name="image_latents", + required=True, + type_hint=torch.Tensor, + description="The image latents to use for the denoising process. Can be generated in vae encoder and packed in input step.", + ), + InputParam( + name="timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="initial_noise", + type_hint=torch.Tensor, + description="The initial random noised used for inpainting denoising.", + ), + ] + + @staticmethod + def check_inputs(image_latents, latents): + if image_latents.shape[0] != latents.shape[0]: + raise ValueError( + f"`image_latents` must have have same batch size as `latents`, but got {image_latents.shape[0]} and {latents.shape[0]}" + ) + + if image_latents.ndim != 3: + raise ValueError(f"`image_latents` must have 3 dimensions (patchified), but got {image_latents.ndim}") + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + self.check_inputs(image_latents=block_state.image_latents, latents=block_state.latents) + + # prepare latent timestep + latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0]) + + # make copy of initial_noise + block_state.initial_noise = block_state.latents + + # scale noise + block_state.latents = components.scheduler.scale_noise( + block_state.image_latents, latent_timestep, block_state.latents + ) + + self.set_block_state(state, block_state) + + return components, state + + +class FluxRoPEInputsStep(ModularPipelineBlocks): + model_name = "flux" + + @property + def description(self) -> str: + return "Step that prepares the RoPE inputs for the denoising process. Should be placed after text encoder and latent preparation steps." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam(name="height", required=True), + InputParam(name="width", required=True), + InputParam(name="prompt_embeds"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="txt_ids", + kwargs_type="denoiser_input_fields", + type_hint=list[int], + description="The sequence lengths of the prompt embeds, used for RoPE calculation.", + ), + OutputParam( + name="img_ids", + kwargs_type="denoiser_input_fields", + type_hint=list[int], + description="The sequence lengths of the image latents, used for RoPE calculation.", + ), + ] + + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + prompt_embeds = block_state.prompt_embeds + device, dtype = prompt_embeds.device, prompt_embeds.dtype + block_state.txt_ids = torch.zeros(prompt_embeds.shape[1], 3).to( + device=prompt_embeds.device, dtype=prompt_embeds.dtype + ) + + height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2)) + width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2)) + block_state.img_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype) + + self.set_block_state(state, block_state) + + return components, state + + +class FluxKontextRoPEInputsStep(ModularPipelineBlocks): + model_name = "flux-kontext" + + @property + def description(self) -> str: + return "Step that prepares the RoPE inputs for the denoising process of Flux Kontext. Should be placed after text encoder and latent preparation steps." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam(name="image_height"), + InputParam(name="image_width"), + InputParam(name="height"), + InputParam(name="width"), + InputParam(name="prompt_embeds"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="txt_ids", + kwargs_type="denoiser_input_fields", + type_hint=list[int], + description="The sequence lengths of the prompt embeds, used for RoPE calculation.", + ), + OutputParam( + name="img_ids", + kwargs_type="denoiser_input_fields", + type_hint=list[int], + description="The sequence lengths of the image latents, used for RoPE calculation.", + ), + ] + + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + prompt_embeds = block_state.prompt_embeds + device, dtype = prompt_embeds.device, prompt_embeds.dtype + block_state.txt_ids = torch.zeros(prompt_embeds.shape[1], 3).to( + device=prompt_embeds.device, dtype=prompt_embeds.dtype + ) + + img_ids = None + if ( + getattr(block_state, "image_height", None) is not None + and getattr(block_state, "image_width", None) is not None + ): + image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2)) + image_latent_width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2)) + img_ids = FluxPipeline._prepare_latent_image_ids( + None, image_latent_height // 2, image_latent_width // 2, device, dtype + ) + # image ids are the same as latent ids with the first dimension set to 1 instead of 0 + img_ids[..., 0] = 1 + + height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2)) + width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2)) + latent_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype) + + if img_ids is not None: + latent_ids = torch.cat([latent_ids, img_ids], dim=0) + + block_state.img_ids = latent_ids + + self.set_block_state(state, block_state) + + return components, state diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/decoders.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..5da861e78fcba81baa14d4eacadae232a17b324f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/decoders.py @@ -0,0 +1,109 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import numpy as np +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKL +from ...utils import logging +from ...video_processor import VaeImageProcessor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + +class FluxDecodeStep(ModularPipelineBlocks): + model_name = "flux" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images" + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("output_type", default="pil"), + InputParam("height", default=1024), + InputParam("width", default=1024), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents from the denoising step", + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "images", + type_hint=list[PIL.Image.Image] | torch.Tensor | np.ndarray, + description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array", + ) + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + vae = components.vae + + if not block_state.output_type == "latent": + latents = block_state.latents + latents = _unpack_latents(latents, block_state.height, block_state.width, components.vae_scale_factor) + latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor + block_state.images = vae.decode(latents, return_dict=False)[0] + block_state.images = components.image_processor.postprocess( + block_state.images, output_type=block_state.output_type + ) + else: + block_state.images = block_state.latents + + self.set_block_state(state, block_state) + + return components, state diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/denoise.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..babb4a867e5958efcf4921104336999c20b98de0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/denoise.py @@ -0,0 +1,330 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from ...models import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import FluxModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class FluxLoopDenoiser(ModularPipelineBlocks): + model_name = "flux" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("transformer", FluxTransformer2DModel)] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `FluxDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("joint_attention_kwargs"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "guidance", + required=False, + type_hint=torch.Tensor, + description="Guidance scale as a tensor", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Prompt embeddings", + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pooled prompt embeddings", + ), + InputParam( + "txt_ids", + required=True, + type_hint=torch.Tensor, + description="IDs computed from text sequence needed for RoPE", + ), + InputParam( + "img_ids", + required=True, + type_hint=torch.Tensor, + description="IDs computed from image sequence needed for RoPE", + ), + ] + + @torch.no_grad() + def __call__( + self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + noise_pred = components.transformer( + hidden_states=block_state.latents, + timestep=t.flatten() / 1000, + guidance=block_state.guidance, + encoder_hidden_states=block_state.prompt_embeds, + pooled_projections=block_state.pooled_prompt_embeds, + joint_attention_kwargs=block_state.joint_attention_kwargs, + txt_ids=block_state.txt_ids, + img_ids=block_state.img_ids, + return_dict=False, + )[0] + block_state.noise_pred = noise_pred + + return components, block_state + + +class FluxKontextLoopDenoiser(ModularPipelineBlocks): + model_name = "flux-kontext" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("transformer", FluxTransformer2DModel)] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents for Flux Kontext. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `FluxDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("joint_attention_kwargs"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "image_latents", + type_hint=torch.Tensor, + description="Image latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "guidance", + required=False, + type_hint=torch.Tensor, + description="Guidance scale as a tensor", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Prompt embeddings", + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pooled prompt embeddings", + ), + InputParam( + "txt_ids", + required=True, + type_hint=torch.Tensor, + description="IDs computed from text sequence needed for RoPE", + ), + InputParam( + "img_ids", + required=True, + type_hint=torch.Tensor, + description="IDs computed from latent sequence needed for RoPE", + ), + ] + + @torch.no_grad() + def __call__( + self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + latents = block_state.latents + latent_model_input = latents + image_latents = block_state.image_latents + if image_latents is not None: + latent_model_input = torch.cat([latent_model_input, image_latents], dim=1) + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + noise_pred = components.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=block_state.guidance, + encoder_hidden_states=block_state.prompt_embeds, + pooled_projections=block_state.pooled_prompt_embeds, + joint_attention_kwargs=block_state.joint_attention_kwargs, + txt_ids=block_state.txt_ids, + img_ids=block_state.img_ids, + return_dict=False, + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + block_state.noise_pred = noise_pred + + return components, block_state + + +class FluxLoopAfterDenoiser(ModularPipelineBlocks): + model_name = "flux" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that update the latents. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `FluxDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [] + + @property + def intermediate_inputs(self) -> list[str]: + return [InputParam("generator")] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + # Perform scheduler step using the predicted output + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred, + t, + block_state.latents, + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class FluxDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "flux" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoise the latents over `timesteps`. " + "The specific steps with each iteration can be customized with `sub_blocks` attributes" + ) + + @property + def loop_expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ComponentSpec("transformer", FluxTransformer2DModel), + ] + + @property + def loop_inputs(self) -> list[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + self.set_block_state(state, block_state) + + return components, state + + +class FluxDenoiseStep(FluxDenoiseLoopWrapper): + block_classes = [FluxLoopDenoiser, FluxLoopAfterDenoiser] + block_names = ["denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `FluxDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `FluxLoopDenoiser`\n" + " - `FluxLoopAfterDenoiser`\n" + "This block supports both text2image and img2img tasks." + ) + + +class FluxKontextDenoiseStep(FluxDenoiseLoopWrapper): + model_name = "flux-kontext" + block_classes = [FluxKontextLoopDenoiser, FluxLoopAfterDenoiser] + block_names = ["denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `FluxDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `FluxKontextLoopDenoiser`\n" + " - `FluxLoopAfterDenoiser`\n" + "This block supports both text2image and img2img tasks." + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/encoders.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..583c139ff22e60173efc31a97654b243e8576d1c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/encoders.py @@ -0,0 +1,480 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html + +import regex as re +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist +from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL +from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import FluxModularPipeline + + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def encode_vae_image(vae: AutoencoderKL, image: torch.Tensor, generator: torch.Generator, sample_mode="sample"): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode) + + image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor + + return image_latents + + +class FluxProcessImagesInputStep(ModularPipelineBlocks): + model_name = "flux" + + @property + def description(self) -> str: + return "Image Preprocess step." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [InputParam("resized_image"), InputParam("image"), InputParam("height"), InputParam("width")] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam(name="processed_image")] + + @staticmethod + def check_inputs(height, width, vae_scale_factor): + if height is not None and height % (vae_scale_factor * 2) != 0: + raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}") + + if width is not None and width % (vae_scale_factor * 2) != 0: + raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}") + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + if block_state.resized_image is None and block_state.image is None: + raise ValueError("`resized_image` and `image` cannot be None at the same time") + + if block_state.resized_image is None: + image = block_state.image + self.check_inputs( + height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor + ) + height = block_state.height or components.default_height + width = block_state.width or components.default_width + else: + width, height = block_state.resized_image[0].size + image = block_state.resized_image + + block_state.processed_image = components.image_processor.preprocess(image=image, height=height, width=width) + + self.set_block_state(state, block_state) + return components, state + + +class FluxKontextProcessImagesInputStep(ModularPipelineBlocks): + model_name = "flux-kontext" + + @property + def description(self) -> str: + return ( + "Image preprocess step for Flux Kontext. The preprocessed image goes to the VAE.\n" + "Kontext works as a T2I model, too, in case no input image is provided." + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [InputParam("image"), InputParam("_auto_resize", type_hint=bool, default=True)] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam(name="processed_image")] + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState): + from ...pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS + + block_state = self.get_block_state(state) + images = block_state.image + + if images is None: + block_state.processed_image = None + + else: + multiple_of = components.image_processor.config.vae_scale_factor + + if not is_valid_image_imagelist(images): + raise ValueError(f"Images must be image or list of images but are {type(images)}") + + if is_valid_image(images): + images = [images] + + img = images[0] + image_height, image_width = components.image_processor.get_default_height_width(img) + aspect_ratio = image_width / image_height + _auto_resize = block_state._auto_resize + if _auto_resize: + # Kontext is trained on specific resolutions, using one of them is recommended + _, image_width, image_height = min( + (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS + ) + image_width = image_width // multiple_of * multiple_of + image_height = image_height // multiple_of * multiple_of + images = components.image_processor.resize(images, image_height, image_width) + block_state.processed_image = components.image_processor.preprocess(images, image_height, image_width) + + self.set_block_state(state, block_state) + return components, state + + +class FluxVaeEncoderStep(ModularPipelineBlocks): + model_name = "flux" + + def __init__( + self, input_name: str = "processed_image", output_name: str = "image_latents", sample_mode: str = "sample" + ): + """Initialize a VAE encoder step for converting images to latent representations. + + Both the input and output names are configurable so this block can be configured to process to different image + inputs (e.g., "processed_image" -> "image_latents", "processed_control_image" -> "control_image_latents"). + + Args: + input_name (str, optional): Name of the input image tensor. Defaults to "processed_image". + Examples: "processed_image" or "processed_control_image" + output_name (str, optional): Name of the output latent tensor. Defaults to "image_latents". + Examples: "image_latents" or "control_image_latents" + sample_mode (str, optional): Sampling mode to be used. + + Examples: + # Basic usage with default settings (includes image processor): # FluxImageVaeEncoderDynamicStep() + + # Custom input/output names for control image: # FluxImageVaeEncoderDynamicStep( + input_name="processed_control_image", output_name="control_image_latents" + ) + """ + self._image_input_name = input_name + self._image_latents_output_name = output_name + self.sample_mode = sample_mode + super().__init__() + + @property + def description(self) -> str: + return f"Dynamic VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n" + + @property + def expected_components(self) -> list[ComponentSpec]: + components = [ComponentSpec("vae", AutoencoderKL)] + return components + + @property + def inputs(self) -> list[InputParam]: + inputs = [InputParam(self._image_input_name), InputParam("generator")] + return inputs + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + self._image_latents_output_name, + type_hint=torch.Tensor, + description="The latents representing the reference image", + ) + ] + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + image = getattr(block_state, self._image_input_name) + + if image is None: + setattr(block_state, self._image_latents_output_name, None) + else: + device = components._execution_device + dtype = components.vae.dtype + image = image.to(device=device, dtype=dtype) + + # Encode image into latents + image_latents = encode_vae_image( + image=image, vae=components.vae, generator=block_state.generator, sample_mode=self.sample_mode + ) + setattr(block_state, self._image_latents_output_name, image_latents) + + self.set_block_state(state, block_state) + + return components, state + + +class FluxTextEncoderStep(ModularPipelineBlocks): + model_name = "flux" + + @property + def description(self) -> str: + return "Text Encoder step that generate text_embeddings to guide the image generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", CLIPTextModel), + ComponentSpec("tokenizer", CLIPTokenizer), + ComponentSpec("text_encoder_2", T5EncoderModel), + ComponentSpec("tokenizer_2", T5TokenizerFast), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("prompt"), + InputParam("prompt_2"), + InputParam("max_sequence_length", type_hint=int, default=512, required=False), + InputParam("joint_attention_kwargs"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="text embeddings used to guide the image generation", + ), + OutputParam( + "pooled_prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="pooled text embeddings used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + for prompt in [block_state.prompt, block_state.prompt_2]: + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` or `prompt_2` has to be of type `str` or `list` but is {type(prompt)}") + + @staticmethod + def _get_t5_prompt_embeds(components, prompt: str | list[str], max_sequence_length: int, device: torch.device): + dtype = components.text_encoder_2.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + + if isinstance(components, TextualInversionLoaderMixin): + prompt = components.maybe_convert_prompt(prompt, components.tokenizer_2) + + text_inputs = components.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + untruncated_ids = components.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = components.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = components.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + return prompt_embeds + + @staticmethod + def _get_clip_prompt_embeds(components, prompt: str | list[str], device: torch.device): + prompt = [prompt] if isinstance(prompt, str) else prompt + + if isinstance(components, TextualInversionLoaderMixin): + prompt = components.maybe_convert_prompt(prompt, components.tokenizer) + + text_inputs = components.tokenizer( + prompt, + padding="max_length", + max_length=components.tokenizer.model_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + tokenizer_max_length = components.tokenizer.model_max_length + untruncated_ids = components.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = components.tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = components.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=components.text_encoder.dtype, device=device) + + return prompt_embeds + + @staticmethod + def encode_prompt( + components, + prompt: str | list[str], + prompt_2: str | list[str], + device: torch.device | None = None, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 512, + lora_scale: float | None = None, + ): + device = device or components._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(components, FluxLoraLoaderMixin): + components._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if components.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(components.text_encoder, lora_scale) + if components.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(components.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = FluxTextEncoderStep._get_clip_prompt_embeds( + components, + prompt=prompt, + device=device, + ) + prompt_embeds = FluxTextEncoderStep._get_t5_prompt_embeds( + components, + prompt=prompt_2, + max_sequence_length=max_sequence_length, + device=device, + ) + + if components.text_encoder is not None: + if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(components.text_encoder, lora_scale) + + if components.text_encoder_2 is not None: + if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(components.text_encoder_2, lora_scale) + + return prompt_embeds, pooled_prompt_embeds + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + # Get inputs and intermediates + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.device = components._execution_device + + # Encode input prompt + block_state.text_encoder_lora_scale = ( + block_state.joint_attention_kwargs.get("scale", None) + if block_state.joint_attention_kwargs is not None + else None + ) + block_state.prompt_embeds, block_state.pooled_prompt_embeds = self.encode_prompt( + components, + prompt=block_state.prompt, + prompt_2=None, + prompt_embeds=None, + pooled_prompt_embeds=None, + device=block_state.device, + max_sequence_length=block_state.max_sequence_length, + lora_scale=block_state.text_encoder_lora_scale, + ) + + # Add outputs + self.set_block_state(state, block_state) + return components, state diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/inputs.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/inputs.py new file mode 100644 index 0000000000000000000000000000000000000000..9d2f69dbe26fb0fdf1fb3c96ed2fb54b181b023f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/inputs.py @@ -0,0 +1,363 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from ...pipelines import FluxPipeline +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import InputParam, OutputParam + +# TODO: consider making these common utilities for modular if they are not pipeline-specific. +from ..qwenimage.inputs import calculate_dimension_from_latents, repeat_tensor_to_batch_size +from .modular_pipeline import FluxModularPipeline + + +logger = logging.get_logger(__name__) + + +class FluxTextInputStep(ModularPipelineBlocks): + model_name = "flux" + + @property + def description(self) -> str: + return ( + "Text input processing step that standardizes text embeddings for the pipeline.\n" + "This step:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + InputParam( + "prompt_embeds", + required=True, + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "pooled_prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.", + ), + # TODO: support negative embeddings? + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `prompt_embeds`)", + ), + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="text embeddings used to guide the image generation", + ), + OutputParam( + "pooled_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="pooled text embeddings used to guide the image generation", + ), + # TODO: support negative embeddings? + ] + + def check_inputs(self, components, block_state): + if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is not None: + if block_state.prompt_embeds.shape[0] != block_state.pooled_prompt_embeds.shape[0]: + raise ValueError( + "`prompt_embeds` and `pooled_prompt_embeds` must have the same batch size when passed directly, but" + f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `pooled_prompt_embeds`" + f" {block_state.pooled_prompt_embeds.shape}." + ) + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + # TODO: consider adding negative embeddings? + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt) + block_state.pooled_prompt_embeds = pooled_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, -1 + ) + self.set_block_state(state, block_state) + + return components, state + + +# Adapted from `QwenImageAdditionalInputsStep` +class FluxAdditionalInputsStep(ModularPipelineBlocks): + model_name = "flux" + + def __init__( + self, + image_latent_inputs: list[str] = ["image_latents"], + additional_batch_inputs: list[str] = [], + ): + if not isinstance(image_latent_inputs, list): + image_latent_inputs = [image_latent_inputs] + if not isinstance(additional_batch_inputs, list): + additional_batch_inputs = [additional_batch_inputs] + + self._image_latent_inputs = image_latent_inputs + self._additional_batch_inputs = additional_batch_inputs + super().__init__() + + @property + def description(self) -> str: + # Functionality section + summary_section = ( + "Input processing step that:\n" + " 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n" + " 2. For additional batch inputs: Expands batch dimensions to match final batch size" + ) + + # Inputs info + inputs_info = "" + if self._image_latent_inputs or self._additional_batch_inputs: + inputs_info = "\n\nConfigured inputs:" + if self._image_latent_inputs: + inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}" + if self._additional_batch_inputs: + inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}" + + # Placement guidance + placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." + + return summary_section + inputs_info + placement_section + + @property + def inputs(self) -> list[InputParam]: + inputs = [ + InputParam(name="num_images_per_prompt", default=1), + InputParam(name="batch_size", required=True), + InputParam(name="height"), + InputParam(name="width"), + ] + + # Add image latent inputs + for image_latent_input_name in self._image_latent_inputs: + inputs.append(InputParam(name=image_latent_input_name)) + + # Add additional batch inputs + for input_name in self._additional_batch_inputs: + inputs.append(InputParam(name=input_name)) + + return inputs + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam(name="image_height", type_hint=int, description="The height of the image latents"), + OutputParam(name="image_width", type_hint=int, description="The width of the image latents"), + ] + + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Process image latent inputs (height/width calculation, patchify, and batch expansion) + for image_latent_input_name in self._image_latent_inputs: + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue + + # 1. Calculate height/width from latents + height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor) + block_state.height = block_state.height or height + block_state.width = block_state.width or width + + if not hasattr(block_state, "image_height"): + block_state.image_height = height + if not hasattr(block_state, "image_width"): + block_state.image_width = width + + # 2. Patchify the image latent tensor + # TODO: Implement patchifier for Flux. + latent_height, latent_width = image_latent_tensor.shape[2:] + image_latent_tensor = FluxPipeline._pack_latents( + image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width + ) + + # 3. Expand batch size + image_latent_tensor = repeat_tensor_to_batch_size( + input_name=image_latent_input_name, + input_tensor=image_latent_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, image_latent_input_name, image_latent_tensor) + + # Process additional batch inputs (only batch expansion) + for input_name in self._additional_batch_inputs: + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue + + # Only expand batch size + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state + + +class FluxKontextAdditionalInputsStep(FluxAdditionalInputsStep): + model_name = "flux-kontext" + + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Process image latent inputs (height/width calculation, patchify, and batch expansion) + for image_latent_input_name in self._image_latent_inputs: + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue + + # 1. Calculate height/width from latents + # Unlike the `FluxAdditionalInputsStep`, we don't overwrite the `block.height` and `block.width` + height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor) + if not hasattr(block_state, "image_height"): + block_state.image_height = height + if not hasattr(block_state, "image_width"): + block_state.image_width = width + + # 2. Patchify the image latent tensor + # TODO: Implement patchifier for Flux. + latent_height, latent_width = image_latent_tensor.shape[2:] + image_latent_tensor = FluxPipeline._pack_latents( + image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width + ) + + # 3. Expand batch size + image_latent_tensor = repeat_tensor_to_batch_size( + input_name=image_latent_input_name, + input_tensor=image_latent_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, image_latent_input_name, image_latent_tensor) + + # Process additional batch inputs (only batch expansion) + for input_name in self._additional_batch_inputs: + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue + + # Only expand batch size + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state + + +class FluxKontextSetResolutionStep(ModularPipelineBlocks): + model_name = "flux-kontext" + + @property + def description(self): + return ( + "Determines the height and width to be used during the subsequent computations.\n" + "It should always be placed _before_ the latent preparation step." + ) + + @property + def inputs(self) -> list[InputParam]: + inputs = [ + InputParam(name="height"), + InputParam(name="width"), + InputParam(name="max_area", type_hint=int, default=1024**2), + ] + return inputs + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam(name="height", type_hint=int, description="The height of the initial noisy latents"), + OutputParam(name="width", type_hint=int, description="The width of the initial noisy latents"), + ] + + @staticmethod + def check_inputs(height, width, vae_scale_factor): + if height is not None and height % (vae_scale_factor * 2) != 0: + raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}") + + if width is not None and width % (vae_scale_factor * 2) != 0: + raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}") + + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + height = block_state.height or components.default_height + width = block_state.width or components.default_width + self.check_inputs(height, width, components.vae_scale_factor) + + original_height, original_width = height, width + max_area = block_state.max_area + aspect_ratio = width / height + width = round((max_area * aspect_ratio) ** 0.5) + height = round((max_area / aspect_ratio) ** 0.5) + + multiple_of = components.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + if height != original_height or width != original_width: + logger.warning( + f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements." + ) + + block_state.height = height + block_state.width = width + + self.set_block_state(state, block_state) + return components, state diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/modular_blocks_flux.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/modular_blocks_flux.py new file mode 100644 index 0000000000000000000000000000000000000000..f2e78e9334482234f2473403c2d53f8a6407f835 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/modular_blocks_flux.py @@ -0,0 +1,586 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict, OutputParam +from .before_denoise import ( + FluxImg2ImgPrepareLatentsStep, + FluxImg2ImgSetTimestepsStep, + FluxPrepareLatentsStep, + FluxRoPEInputsStep, + FluxSetTimestepsStep, +) +from .decoders import FluxDecodeStep +from .denoise import FluxDenoiseStep +from .encoders import ( + FluxProcessImagesInputStep, + FluxTextEncoderStep, + FluxVaeEncoderStep, +) +from .inputs import ( + FluxAdditionalInputsStep, + FluxTextInputStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# vae encoder (run before before_denoise) + + +# auto_docstring +class FluxImg2ImgVaeEncoderStep(SequentialPipelineBlocks): + """ + Vae encoder step that preprocess andencode the image inputs into their latent representations. + + Components: + image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`) + + Inputs: + resized_image (`None`, *optional*): + TODO: Add description. + image (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + processed_image (`None`): + TODO: Add description. + image_latents (`Tensor`): + The latents representing the reference image + """ + + model_name = "flux" + + block_classes = [FluxProcessImagesInputStep(), FluxVaeEncoderStep()] + block_names = ["preprocess", "encode"] + + @property + def description(self) -> str: + return "Vae encoder step that preprocess andencode the image inputs into their latent representations." + + +# auto_docstring +class FluxAutoVaeEncoderStep(AutoPipelineBlocks): + """ + Vae encoder step that encode the image inputs into their latent representations. + This is an auto pipeline block that works for img2img tasks. + - `FluxImg2ImgVaeEncoderStep` (img2img) is used when only `image` is provided. - if `image` is not provided, + step will be skipped. + + Components: + image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`) + + Inputs: + resized_image (`None`, *optional*): + TODO: Add description. + image (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + processed_image (`None`): + TODO: Add description. + image_latents (`Tensor`): + The latents representing the reference image + """ + + model_name = "flux" + block_classes = [FluxImg2ImgVaeEncoderStep] + block_names = ["img2img"] + block_trigger_inputs = ["image"] + + @property + def description(self): + return ( + "Vae encoder step that encode the image inputs into their latent representations.\n" + + "This is an auto pipeline block that works for img2img tasks.\n" + + " - `FluxImg2ImgVaeEncoderStep` (img2img) is used when only `image` is provided." + + " - if `image` is not provided, step will be skipped." + ) + + +# before_denoise: text2img +# auto_docstring +class FluxBeforeDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepares the inputs for the denoise step in text-to-image generation. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + num_images_per_prompt (`int`, *optional*, defaults to 1): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. + Can be generated in input step. + dtype (`dtype`, *optional*): + The dtype of the model inputs + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 3.5): + TODO: Add description. + prompt_embeds (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + The initial latents to use for the denoising process + timesteps (`Tensor`): + The timesteps to use for inference + num_inference_steps (`int`): + The number of denoising steps to perform at inference time + guidance (`Tensor`): + Optional guidance to be used. + txt_ids (`list`): + The sequence lengths of the prompt embeds, used for RoPE calculation. + img_ids (`list`): + The sequence lengths of the image latents, used for RoPE calculation. + """ + + model_name = "flux" + block_classes = [FluxPrepareLatentsStep(), FluxSetTimestepsStep(), FluxRoPEInputsStep()] + block_names = ["prepare_latents", "set_timesteps", "prepare_rope_inputs"] + + @property + def description(self): + return "Before denoise step that prepares the inputs for the denoise step in text-to-image generation." + + +# before_denoise: img2img +# auto_docstring +class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs for the denoise step for img2img task. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + num_images_per_prompt (`int`, *optional*, defaults to 1): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. + Can be generated in input step. + dtype (`dtype`, *optional*): + The dtype of the model inputs + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + strength (`None`, *optional*, defaults to 0.6): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 3.5): + TODO: Add description. + image_latents (`Tensor`): + The image latents to use for the denoising process. Can be generated in vae encoder and packed in input + step. + prompt_embeds (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + The initial latents to use for the denoising process + timesteps (`Tensor`): + The timesteps to use for inference + num_inference_steps (`int`): + The number of denoising steps to perform at inference time + guidance (`Tensor`): + Optional guidance to be used. + initial_noise (`Tensor`): + The initial random noised used for inpainting denoising. + txt_ids (`list`): + The sequence lengths of the prompt embeds, used for RoPE calculation. + img_ids (`list`): + The sequence lengths of the image latents, used for RoPE calculation. + """ + + model_name = "flux" + block_classes = [ + FluxPrepareLatentsStep(), + FluxImg2ImgSetTimestepsStep(), + FluxImg2ImgPrepareLatentsStep(), + FluxRoPEInputsStep(), + ] + block_names = ["prepare_latents", "set_timesteps", "prepare_img2img_latents", "prepare_rope_inputs"] + + @property + def description(self): + return "Before denoise step that prepare the inputs for the denoise step for img2img task." + + +# before_denoise: all task (text2img, img2img) +# auto_docstring +class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks): + """ + Before denoise step that prepare the inputs for the denoise step. + This is an auto pipeline block that works for text2image. + - `FluxBeforeDenoiseStep` (text2image) is used. + - `FluxImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + height (`int`): + TODO: Add description. + width (`int`): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + num_images_per_prompt (`int`, *optional*, defaults to 1): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. + Can be generated in input step. + dtype (`dtype`, *optional*): + The dtype of the model inputs + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + strength (`None`, *optional*, defaults to 0.6): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 3.5): + TODO: Add description. + image_latents (`Tensor`, *optional*): + The image latents to use for the denoising process. Can be generated in vae encoder and packed in input + step. + prompt_embeds (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + The initial latents to use for the denoising process + timesteps (`Tensor`): + The timesteps to use for inference + num_inference_steps (`int`): + The number of denoising steps to perform at inference time + guidance (`Tensor`): + Optional guidance to be used. + initial_noise (`Tensor`): + The initial random noised used for inpainting denoising. + txt_ids (`list`): + The sequence lengths of the prompt embeds, used for RoPE calculation. + img_ids (`list`): + The sequence lengths of the image latents, used for RoPE calculation. + """ + + model_name = "flux" + block_classes = [FluxImg2ImgBeforeDenoiseStep, FluxBeforeDenoiseStep] + block_names = ["img2img", "text2image"] + block_trigger_inputs = ["image_latents", None] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs for the denoise step.\n" + + "This is an auto pipeline block that works for text2image.\n" + + " - `FluxBeforeDenoiseStep` (text2image) is used.\n" + + " - `FluxImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + ) + + +# inputs: text2image/img2img + + +# auto_docstring +class FluxImg2ImgInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the img2img denoising step. It: + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be generated from text_encoder step. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + image_latents (`None`, *optional*): + TODO: Add description. + + Outputs: + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt + dtype (`dtype`): + Data type of model tensor inputs (determined by `prompt_embeds`) + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation + pooled_prompt_embeds (`Tensor`): + pooled text embeddings used to guide the image generation + image_height (`int`): + The height of the image latents + image_width (`int`): + The width of the image latents + """ + + model_name = "flux" + block_classes = [FluxTextInputStep(), FluxAdditionalInputsStep()] + block_names = ["text_inputs", "additional_inputs"] + + @property + def description(self): + return "Input step that prepares the inputs for the img2img denoising step. It:\n" + " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n" + " - update height/width based `image_latents`, patchify `image_latents`." + + +# auto_docstring +class FluxAutoInputStep(AutoPipelineBlocks): + """ + Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, + and patchified. + This is an auto pipeline block that works for text2image/img2img tasks. + - `FluxImg2ImgInputStep` (img2img) is used when `image_latents` is provided. + - `FluxTextInputStep` (text2image) is used when `image_latents` are not provided. + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be generated from text_encoder step. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + image_latents (`None`, *optional*): + TODO: Add description. + + Outputs: + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt + dtype (`dtype`): + Data type of model tensor inputs (determined by `prompt_embeds`) + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation + pooled_prompt_embeds (`Tensor`): + pooled text embeddings used to guide the image generation + image_height (`int`): + The height of the image latents + image_width (`int`): + The width of the image latents + """ + + model_name = "flux" + + block_classes = [FluxImg2ImgInputStep, FluxTextInputStep] + block_names = ["img2img", "text2image"] + block_trigger_inputs = ["image_latents", None] + + @property + def description(self): + return ( + "Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n" + " This is an auto pipeline block that works for text2image/img2img tasks.\n" + + " - `FluxImg2ImgInputStep` (img2img) is used when `image_latents` is provided.\n" + + " - `FluxTextInputStep` (text2image) is used when `image_latents` are not provided.\n" + ) + + +# auto_docstring +class FluxCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core step that performs the denoising process for Flux. + This step supports text-to-image and image-to-image tasks for Flux: + - for image-to-image generation, you need to provide `image_latents` + - for text-to-image generation, all you need to provide is prompt embeddings. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`FluxTransformer2DModel`) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be generated from text_encoder step. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + image_latents (`None`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + strength (`None`, *optional*, defaults to 0.6): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 3.5): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "flux" + block_classes = [FluxAutoInputStep, FluxAutoBeforeDenoiseStep, FluxDenoiseStep] + block_names = ["input", "before_denoise", "denoise"] + + @property + def description(self): + return ( + "Core step that performs the denoising process for Flux.\n" + + "This step supports text-to-image and image-to-image tasks for Flux:\n" + + " - for image-to-image generation, you need to provide `image_latents`\n" + + " - for text-to-image generation, all you need to provide is prompt embeddings." + ) + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Auto blocks (text2image and img2img) +AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", FluxTextEncoderStep()), + ("vae_encoder", FluxAutoVaeEncoderStep()), + ("denoise", FluxCoreDenoiseStep()), + ("decode", FluxDecodeStep()), + ] +) + + +# auto_docstring +class FluxAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for text-to-image and image-to-image using Flux. + + Supported workflows: + - `text2image`: requires `prompt` + - `image2image`: requires `image`, `prompt` + + Components: + text_encoder (`CLIPTextModel`) tokenizer (`CLIPTokenizer`) text_encoder_2 (`T5EncoderModel`) tokenizer_2 + (`T5TokenizerFast`) image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`) scheduler + (`FlowMatchEulerDiscreteScheduler`) transformer (`FluxTransformer2DModel`) + + Inputs: + prompt (`None`, *optional*): + TODO: Add description. + prompt_2 (`None`, *optional*): + TODO: Add description. + max_sequence_length (`int`, *optional*, defaults to 512): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + resized_image (`None`, *optional*): + TODO: Add description. + image (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + image_latents (`None`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + strength (`None`, *optional*, defaults to 0.6): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 3.5): + TODO: Add description. + output_type (`None`, *optional*, defaults to pil): + TODO: Add description. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "flux" + + block_classes = AUTO_BLOCKS.values() + block_names = AUTO_BLOCKS.keys() + + _workflow_map = { + "text2image": {"prompt": True}, + "image2image": {"image": True, "prompt": True}, + } + + @property + def description(self): + return "Auto Modular pipeline for text-to-image and image-to-image using Flux." + + @property + def outputs(self): + return [OutputParam.template("images")] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/modular_blocks_flux_kontext.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/modular_blocks_flux_kontext.py new file mode 100644 index 0000000000000000000000000000000000000000..b5a5dbf78c0eb9ef2cd7a8102aca2a29aa9b6db3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/modular_blocks_flux_kontext.py @@ -0,0 +1,585 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict, OutputParam +from .before_denoise import ( + FluxKontextRoPEInputsStep, + FluxPrepareLatentsStep, + FluxRoPEInputsStep, + FluxSetTimestepsStep, +) +from .decoders import FluxDecodeStep +from .denoise import FluxKontextDenoiseStep +from .encoders import ( + FluxKontextProcessImagesInputStep, + FluxTextEncoderStep, + FluxVaeEncoderStep, +) +from .inputs import ( + FluxKontextAdditionalInputsStep, + FluxKontextSetResolutionStep, + FluxTextInputStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Flux Kontext vae encoder (run before before_denoise) +# auto_docstring +class FluxKontextVaeEncoderStep(SequentialPipelineBlocks): + """ + Vae encoder step that preprocess andencode the image inputs into their latent representations. + + Components: + image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`) + + Inputs: + image (`None`, *optional*): + TODO: Add description. + _auto_resize (`bool`, *optional*, defaults to True): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + processed_image (`None`): + TODO: Add description. + image_latents (`Tensor`): + The latents representing the reference image + """ + + model_name = "flux-kontext" + + block_classes = [FluxKontextProcessImagesInputStep(), FluxVaeEncoderStep(sample_mode="argmax")] + block_names = ["preprocess", "encode"] + + @property + def description(self) -> str: + return "Vae encoder step that preprocess andencode the image inputs into their latent representations." + + +# auto_docstring +class FluxKontextAutoVaeEncoderStep(AutoPipelineBlocks): + """ + Vae encoder step that encode the image inputs into their latent representations. + This is an auto pipeline block that works for image-conditioned tasks. + - `FluxKontextVaeEncoderStep` (image_conditioned) is used when only `image` is provided. - if `image` is not + provided, step will be skipped. + + Components: + image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`) + + Inputs: + image (`None`, *optional*): + TODO: Add description. + _auto_resize (`bool`, *optional*, defaults to True): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + processed_image (`None`): + TODO: Add description. + image_latents (`Tensor`): + The latents representing the reference image + """ + + model_name = "flux-kontext" + + block_classes = [FluxKontextVaeEncoderStep] + block_names = ["image_conditioned"] + block_trigger_inputs = ["image"] + + @property + def description(self): + return ( + "Vae encoder step that encode the image inputs into their latent representations.\n" + + "This is an auto pipeline block that works for image-conditioned tasks.\n" + + " - `FluxKontextVaeEncoderStep` (image_conditioned) is used when only `image` is provided." + + " - if `image` is not provided, step will be skipped." + ) + + +# before_denoise: text2img +# auto_docstring +class FluxKontextBeforeDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepares the inputs for the denoise step for Flux Kontext + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + num_images_per_prompt (`int`, *optional*, defaults to 1): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. + Can be generated in input step. + dtype (`dtype`, *optional*): + The dtype of the model inputs + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 3.5): + TODO: Add description. + prompt_embeds (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + The initial latents to use for the denoising process + timesteps (`Tensor`): + The timesteps to use for inference + num_inference_steps (`int`): + The number of denoising steps to perform at inference time + guidance (`Tensor`): + Optional guidance to be used. + txt_ids (`list`): + The sequence lengths of the prompt embeds, used for RoPE calculation. + img_ids (`list`): + The sequence lengths of the image latents, used for RoPE calculation. + """ + + model_name = "flux-kontext" + + block_classes = [FluxPrepareLatentsStep(), FluxSetTimestepsStep(), FluxRoPEInputsStep()] + block_names = ["prepare_latents", "set_timesteps", "prepare_rope_inputs"] + + @property + def description(self): + return "Before denoise step that prepares the inputs for the denoise step for Flux Kontext\n" + "for text-to-image tasks." + + +# before_denoise: image-conditioned +# auto_docstring +class FluxKontextImageConditionedBeforeDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs for the denoise step for Flux Kontext + for image-conditioned tasks. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + num_images_per_prompt (`int`, *optional*, defaults to 1): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. + Can be generated in input step. + dtype (`dtype`, *optional*): + The dtype of the model inputs + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 3.5): + TODO: Add description. + image_height (`None`, *optional*): + TODO: Add description. + image_width (`None`, *optional*): + TODO: Add description. + prompt_embeds (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + The initial latents to use for the denoising process + timesteps (`Tensor`): + The timesteps to use for inference + num_inference_steps (`int`): + The number of denoising steps to perform at inference time + guidance (`Tensor`): + Optional guidance to be used. + txt_ids (`list`): + The sequence lengths of the prompt embeds, used for RoPE calculation. + img_ids (`list`): + The sequence lengths of the image latents, used for RoPE calculation. + """ + + model_name = "flux-kontext" + + block_classes = [FluxPrepareLatentsStep(), FluxSetTimestepsStep(), FluxKontextRoPEInputsStep()] + block_names = ["prepare_latents", "set_timesteps", "prepare_rope_inputs"] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs for the denoise step for Flux Kontext\n" + "for image-conditioned tasks." + ) + + +# auto_docstring +class FluxKontextAutoBeforeDenoiseStep(AutoPipelineBlocks): + """ + Before denoise step that prepare the inputs for the denoise step. + This is an auto pipeline block that works for text2image. + - `FluxKontextBeforeDenoiseStep` (text2image) is used. + - `FluxKontextImageConditionedBeforeDenoiseStep` (image_conditioned) is used when only `image_latents` is + provided. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + num_images_per_prompt (`int`, *optional*, defaults to 1): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. + Can be generated in input step. + dtype (`dtype`, *optional*): + The dtype of the model inputs + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 3.5): + TODO: Add description. + image_height (`None`, *optional*): + TODO: Add description. + image_width (`None`, *optional*): + TODO: Add description. + prompt_embeds (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + The initial latents to use for the denoising process + timesteps (`Tensor`): + The timesteps to use for inference + num_inference_steps (`int`): + The number of denoising steps to perform at inference time + guidance (`Tensor`): + Optional guidance to be used. + txt_ids (`list`): + The sequence lengths of the prompt embeds, used for RoPE calculation. + img_ids (`list`): + The sequence lengths of the image latents, used for RoPE calculation. + """ + + model_name = "flux-kontext" + + block_classes = [FluxKontextImageConditionedBeforeDenoiseStep, FluxKontextBeforeDenoiseStep] + block_names = ["image_conditioned", "text2image"] + block_trigger_inputs = ["image_latents", None] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs for the denoise step.\n" + + "This is an auto pipeline block that works for text2image.\n" + + " - `FluxKontextBeforeDenoiseStep` (text2image) is used.\n" + + " - `FluxKontextImageConditionedBeforeDenoiseStep` (image_conditioned) is used when only `image_latents` is provided.\n" + ) + + +# inputs: Flux Kontext +# auto_docstring +class FluxKontextInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the both text2img and img2img denoising step. It: + - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`). + - update height/width based `image_latents`, patchify `image_latents`. + + Inputs: + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + max_area (`int`, *optional*, defaults to 1048576): + TODO: Add description. + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be generated from text_encoder step. + image_latents (`None`, *optional*): + TODO: Add description. + + Outputs: + height (`int`): + The height of the initial noisy latents + width (`int`): + The width of the initial noisy latents + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt + dtype (`dtype`): + Data type of model tensor inputs (determined by `prompt_embeds`) + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation + pooled_prompt_embeds (`Tensor`): + pooled text embeddings used to guide the image generation + image_height (`int`): + The height of the image latents + image_width (`int`): + The width of the image latents + """ + + model_name = "flux-kontext" + block_classes = [FluxKontextSetResolutionStep(), FluxTextInputStep(), FluxKontextAdditionalInputsStep()] + block_names = ["set_resolution", "text_inputs", "additional_inputs"] + + @property + def description(self): + return ( + "Input step that prepares the inputs for the both text2img and img2img denoising step. It:\n" + " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n" + " - update height/width based `image_latents`, patchify `image_latents`." + ) + + +# auto_docstring +class FluxKontextAutoInputStep(AutoPipelineBlocks): + """ + Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, + and patchified. + This is an auto pipeline block that works for text2image/img2img tasks. + - `FluxKontextInputStep` (image_conditioned) is used when `image_latents` is provided. + - `FluxKontextInputStep` is also capable of handling text2image task when `image_latent` isn't present. + + Inputs: + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + max_area (`int`, *optional*, defaults to 1048576): + TODO: Add description. + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be generated from text_encoder step. + image_latents (`None`, *optional*): + TODO: Add description. + + Outputs: + height (`int`): + The height of the initial noisy latents + width (`int`): + The width of the initial noisy latents + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt + dtype (`dtype`): + Data type of model tensor inputs (determined by `prompt_embeds`) + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation + pooled_prompt_embeds (`Tensor`): + pooled text embeddings used to guide the image generation + image_height (`int`): + The height of the image latents + image_width (`int`): + The width of the image latents + """ + + model_name = "flux-kontext" + block_classes = [FluxKontextInputStep, FluxTextInputStep] + block_names = ["image_conditioned", "text2image"] + block_trigger_inputs = ["image_latents", None] + + @property + def description(self): + return ( + "Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n" + " This is an auto pipeline block that works for text2image/img2img tasks.\n" + + " - `FluxKontextInputStep` (image_conditioned) is used when `image_latents` is provided.\n" + + " - `FluxKontextInputStep` is also capable of handling text2image task when `image_latent` isn't present." + ) + + +# auto_docstring +class FluxKontextCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core step that performs the denoising process for Flux Kontext. + This step supports text-to-image and image-conditioned tasks for Flux Kontext: + - for image-conditioned generation, you need to provide `image_latents` + - for text-to-image generation, all you need to provide is prompt embeddings. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`FluxTransformer2DModel`) + + Inputs: + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + max_area (`int`, *optional*, defaults to 1048576): + TODO: Add description. + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be generated from text_encoder step. + image_latents (`None`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 3.5): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "flux-kontext" + block_classes = [FluxKontextAutoInputStep, FluxKontextAutoBeforeDenoiseStep, FluxKontextDenoiseStep] + block_names = ["input", "before_denoise", "denoise"] + + @property + def description(self): + return ( + "Core step that performs the denoising process for Flux Kontext.\n" + + "This step supports text-to-image and image-conditioned tasks for Flux Kontext:\n" + + " - for image-conditioned generation, you need to provide `image_latents`\n" + + " - for text-to-image generation, all you need to provide is prompt embeddings." + ) + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +AUTO_BLOCKS_KONTEXT = InsertableDict( + [ + ("text_encoder", FluxTextEncoderStep()), + ("vae_encoder", FluxKontextAutoVaeEncoderStep()), + ("denoise", FluxKontextCoreDenoiseStep()), + ("decode", FluxDecodeStep()), + ] +) + + +# auto_docstring +class FluxKontextAutoBlocks(SequentialPipelineBlocks): + """ + Modular pipeline for image-to-image using Flux Kontext. + + Supported workflows: + - `image_conditioned`: requires `image`, `prompt` + - `text2image`: requires `prompt` + + Components: + text_encoder (`CLIPTextModel`) tokenizer (`CLIPTokenizer`) text_encoder_2 (`T5EncoderModel`) tokenizer_2 + (`T5TokenizerFast`) image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`) scheduler + (`FlowMatchEulerDiscreteScheduler`) transformer (`FluxTransformer2DModel`) + + Inputs: + prompt (`None`, *optional*): + TODO: Add description. + prompt_2 (`None`, *optional*): + TODO: Add description. + max_sequence_length (`int`, *optional*, defaults to 512): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + image (`None`, *optional*): + TODO: Add description. + _auto_resize (`bool`, *optional*, defaults to True): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + max_area (`int`, *optional*, defaults to 1048576): + TODO: Add description. + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + image_latents (`None`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 3.5): + TODO: Add description. + output_type (`None`, *optional*, defaults to pil): + TODO: Add description. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "flux-kontext" + + block_classes = AUTO_BLOCKS_KONTEXT.values() + block_names = AUTO_BLOCKS_KONTEXT.keys() + _workflow_map = { + "image_conditioned": {"image": True, "prompt": True}, + "text2image": {"prompt": True}, + } + + @property + def description(self): + return "Modular pipeline for image-to-image using Flux Kontext." + + @property + def outputs(self): + return [OutputParam.template("images")] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/modular_pipeline.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/modular_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..d8158f5d4fd618751e328bb92f699b818507ba05 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux/modular_pipeline.py @@ -0,0 +1,67 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin +from ...utils import logging +from ..modular_pipeline import ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin, TextualInversionLoaderMixin): + """ + A ModularPipeline for Flux. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "FluxAutoBlocks" + + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + return 128 + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if getattr(self, "vae", None) is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 16 + if getattr(self, "transformer", None): + num_channels_latents = self.transformer.config.in_channels // 4 + return num_channels_latents + + +class FluxKontextModularPipeline(FluxModularPipeline): + """ + A ModularPipeline for Flux Kontext. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "FluxKontextAutoBlocks" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d7cc8badcaf7d0e51f7e6eb2923df6c04d4172cd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__init__.py @@ -0,0 +1,57 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["encoders"] = ["Flux2RemoteTextEncoderStep"] + _import_structure["modular_blocks_flux2"] = ["Flux2AutoBlocks"] + _import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks"] + _import_structure["modular_blocks_flux2_klein_base"] = ["Flux2KleinBaseAutoBlocks"] + _import_structure["modular_pipeline"] = [ + "Flux2KleinBaseModularPipeline", + "Flux2KleinModularPipeline", + "Flux2ModularPipeline", + ] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .encoders import Flux2RemoteTextEncoderStep + from .modular_blocks_flux2 import Flux2AutoBlocks + from .modular_blocks_flux2_klein import Flux2KleinAutoBlocks + from .modular_blocks_flux2_klein_base import Flux2KleinBaseAutoBlocks + from .modular_pipeline import Flux2KleinBaseModularPipeline, Flux2KleinModularPipeline, Flux2ModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..054b7b4ba0b3251da5bd4fd713aa9649525f8ff5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/before_denoise.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/before_denoise.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8de7b47658f530206ac2eeac297999fb28da1ee2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/before_denoise.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/decoders.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/decoders.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6c31531b6281fda728bff1342b951bdba0491fe Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/decoders.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/denoise.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/denoise.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04c38f8f8dd8d769e013961b526bb22887482280 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/denoise.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/encoders.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/encoders.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ce51aecf595387f7766cfa5c3a7b6f24fb2983a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/encoders.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/inputs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/inputs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f61f0f5284ebe28ea0f4186f9da63b7d8e6abe4b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/inputs.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/modular_blocks_flux2.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/modular_blocks_flux2.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60c157f621895a4dd054568b6f413eaf5d64ceae Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/modular_blocks_flux2.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/modular_blocks_flux2_klein.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/modular_blocks_flux2_klein.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49684637b6579eed03f6b564773189efb07745d0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/modular_blocks_flux2_klein.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/modular_blocks_flux2_klein_base.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/modular_blocks_flux2_klein_base.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41c7c05eb1fb82040e0d8a73dd29cfae557bfeac Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/modular_blocks_flux2_klein_base.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/modular_pipeline.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/modular_pipeline.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e65de2b8471c16bf526cf7ec17ee6f05b300600 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/__pycache__/modular_pipeline.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/before_denoise.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/before_denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..6b1b3bd96324dd0fc90a293c18246a18e2133541 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/before_denoise.py @@ -0,0 +1,591 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import numpy as np +import torch + +from ...models import Flux2Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import Flux2ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + """Compute empirical mu for Flux2 timestep scheduling.""" + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class Flux2SetTimestepsStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ComponentSpec("transformer", Flux2Transformer2DModel), + ] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for Flux2 inference using empirical mu calculation" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("latents", type_hint=torch.Tensor), + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam( + "num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time", + ), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + + scheduler = components.scheduler + + height = block_state.height or components.default_height + width = block_state.width or components.default_width + vae_scale_factor = components.vae_scale_factor + + latent_height = 2 * (int(height) // (vae_scale_factor * 2)) + latent_width = 2 * (int(width) // (vae_scale_factor * 2)) + image_seq_len = (latent_height // 2) * (latent_width // 2) + + num_inference_steps = block_state.num_inference_steps + sigmas = block_state.sigmas + timesteps = block_state.timesteps + + if timesteps is None and sigmas is None: + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas: + sigmas = None + + mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps) + + timesteps, num_inference_steps = retrieve_timesteps( + scheduler, + num_inference_steps, + device, + timesteps=timesteps, + sigmas=sigmas, + mu=mu, + ) + block_state.timesteps = timesteps + block_state.num_inference_steps = num_inference_steps + + components.scheduler.set_begin_index(0) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2PrepareLatentsStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [] + + @property + def description(self) -> str: + return "Prepare latents step that prepares the initial noise latents for Flux2 text-to-image generation" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("latents", type_hint=torch.Tensor | None), + InputParam("num_images_per_prompt", type_hint=int, default=1), + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.", + ), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ), + OutputParam("latent_ids", type_hint=torch.Tensor, description="Position IDs for the latents (for RoPE)"), + ] + + @staticmethod + def check_inputs(components, block_state): + vae_scale_factor = components.vae_scale_factor + if (block_state.height is not None and block_state.height % (vae_scale_factor * 2) != 0) or ( + block_state.width is not None and block_state.width % (vae_scale_factor * 2) != 0 + ): + logger.warning( + f"`height` and `width` have to be divisible by {vae_scale_factor * 2} but are {block_state.height} and {block_state.width}." + ) + + @staticmethod + def _prepare_latent_ids(latents: torch.Tensor): + """ + Generates 4D position coordinates (T, H, W, L) for latent tensors. + + Args: + latents: Latent tensor of shape (B, C, H, W) + + Returns: + Position IDs tensor of shape (B, H*W, 4) + """ + batch_size, _, height, width = latents.shape + + t = torch.arange(1) + h = torch.arange(height) + w = torch.arange(width) + l = torch.arange(1) + + latent_ids = torch.cartesian_prod(t, h, w, l) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + + @staticmethod + def _pack_latents(latents): + """Pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)""" + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + return latents + + @staticmethod + def prepare_latents( + comp, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (comp.vae_scale_factor * 2)) + width = 2 * (int(width) // (comp.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents * 4, height // 2, width // 2) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + block_state.device = components._execution_device + block_state.num_channels_latents = components.num_channels_latents + + self.check_inputs(components, block_state) + batch_size = block_state.batch_size * block_state.num_images_per_prompt + + latents = self.prepare_latents( + components, + batch_size, + block_state.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + ) + + latent_ids = self._prepare_latent_ids(latents) + latent_ids = latent_ids.to(block_state.device) + + latents = self._pack_latents(latents) + + block_state.latents = latents + block_state.latent_ids = latent_ids + + self.set_block_state(state, block_state) + return components, state + + +class Flux2RoPEInputsStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "Step that prepares the 4D RoPE position IDs for Flux2 denoising. Should be placed after text encoder and latent preparation steps." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam(name="prompt_embeds", required=True), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="txt_ids", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.", + ), + ] + + @staticmethod + def _prepare_text_ids(x: torch.Tensor, t_coord: torch.Tensor | None = None): + """Prepare 4D position IDs for text tokens.""" + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + seq_l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, seq_l) + out_ids.append(coords) + + return torch.stack(out_ids) + + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + prompt_embeds = block_state.prompt_embeds + device = prompt_embeds.device + + block_state.txt_ids = self._prepare_text_ids(prompt_embeds) + block_state.txt_ids = block_state.txt_ids.to(device) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2KleinBaseRoPEInputsStep(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def description(self) -> str: + return "Step that prepares the 4D RoPE position IDs for Flux2-Klein base model denoising. Should be placed after text encoder and latent preparation steps." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam(name="prompt_embeds", required=True), + InputParam(name="negative_prompt_embeds", required=False), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="txt_ids", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.", + ), + OutputParam( + name="negative_txt_ids", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="4D position IDs (T, H, W, L) for negative text tokens, used for RoPE calculation.", + ), + ] + + @staticmethod + def _prepare_text_ids(x: torch.Tensor, t_coord: torch.Tensor | None = None): + """Prepare 4D position IDs for text tokens.""" + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + seq_l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, seq_l) + out_ids.append(coords) + + return torch.stack(out_ids) + + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + prompt_embeds = block_state.prompt_embeds + device = prompt_embeds.device + + block_state.txt_ids = self._prepare_text_ids(prompt_embeds) + block_state.txt_ids = block_state.txt_ids.to(device) + + block_state.negative_txt_ids = None + if block_state.negative_prompt_embeds is not None: + block_state.negative_txt_ids = self._prepare_text_ids(block_state.negative_prompt_embeds) + block_state.negative_txt_ids = block_state.negative_txt_ids.to(device) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2PrepareImageLatentsStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "Step that prepares image latents and their position IDs for Flux2 image conditioning." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("image_latents", type_hint=list[torch.Tensor]), + InputParam("batch_size", required=True, type_hint=int), + InputParam("num_images_per_prompt", default=1, type_hint=int), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "image_latents", + type_hint=torch.Tensor, + description="Packed image latents for conditioning", + ), + OutputParam( + "image_latent_ids", + type_hint=torch.Tensor, + description="Position IDs for image latents", + ), + ] + + @staticmethod + def _prepare_image_ids(image_latents: list[torch.Tensor], scale: int = 10): + """ + Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents. + + Args: + image_latents: A list of image latent feature tensors of shape (1, C, H, W). + scale: Factor used to define the time separation between latents. + + Returns: + Combined coordinate tensor of shape (1, N_total, 4) + """ + if not isinstance(image_latents, list): + raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.") + + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + x = x.squeeze(0) + _, height, width = x.shape + + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + + @staticmethod + def _pack_latents(latents): + """Pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)""" + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + return latents + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + image_latents = block_state.image_latents + + if image_latents is None: + block_state.image_latents = None + block_state.image_latent_ids = None + self.set_block_state(state, block_state) + + return components, state + + device = components._execution_device + batch_size = block_state.batch_size * block_state.num_images_per_prompt + + image_latent_ids = self._prepare_image_ids(image_latents) + + packed_latents = [] + for latent in image_latents: + packed = self._pack_latents(latent) + packed = packed.squeeze(0) + packed_latents.append(packed) + + image_latents = torch.cat(packed_latents, dim=0) + image_latents = image_latents.unsqueeze(0) + + image_latents = image_latents.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.to(device) + + block_state.image_latents = image_latents + block_state.image_latent_ids = image_latent_ids + + self.set_block_state(state, block_state) + return components, state + + +class Flux2PrepareGuidanceStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "Step that prepares the guidance scale tensor for Flux2 inference" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("guidance_scale", default=4.0), + InputParam("num_images_per_prompt", default=1), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + batch_size = block_state.batch_size * block_state.num_images_per_prompt + guidance = torch.full([1], block_state.guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(batch_size) + block_state.guidance = guidance + + self.set_block_state(state, block_state) + return components, state diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/decoders.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..c8ad9401efff6ef4bd714c294cff064353e362d2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/decoders.py @@ -0,0 +1,185 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any, Union + +import numpy as np +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKLFlux2 +from ...pipelines.flux2.image_processor import Flux2ImageProcessor +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Flux2UnpackLatentsStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "Step that unpacks the latents from the denoising step" + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents from the denoising step", + ), + InputParam( + "latent_ids", + required=True, + type_hint=torch.Tensor, + description="Position IDs for the latents, used for unpacking", + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "latents", + type_hint=torch.Tensor, + description="The denoise latents from denoising step, unpacked with position IDs.", + ) + ] + + @staticmethod + def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> torch.Tensor: + """ + Unpack latents using position IDs to scatter tokens into place. + + Args: + x: Packed latents tensor of shape (B, seq_len, C) + x_ids: Position IDs tensor of shape (B, seq_len, 4) with (T, H, W, L) coordinates + + Returns: + Unpacked latents tensor of shape (B, C, H, W) + """ + x_list = [] + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids = h_ids * w + w_ids + + out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + out = out.view(h, w, ch).permute(2, 0, 1) + x_list.append(out) + + return torch.stack(x_list, dim=0) + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + latents = block_state.latents + latent_ids = block_state.latent_ids + + latents = self._unpack_latents_with_ids(latents, latent_ids) + + block_state.latents = latents + + self.set_block_state(state, block_state) + return components, state + + +class Flux2DecodeStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLFlux2), + ComponentSpec( + "image_processor", + Flux2ImageProcessor, + config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization" + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("output_type", default="pil"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents from the denoising step", + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "images", + type_hint=Union[list[PIL.Image.Image], torch.Tensor, np.ndarray], + description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array", + ) + ] + + @staticmethod + def _unpatchify_latents(latents): + """Convert patchified latents back to regular format.""" + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) + return latents + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + vae = components.vae + + latents = block_state.latents + + latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + latents = latents * latents_bn_std + latents_bn_mean + + latents = self._unpatchify_latents(latents) + + block_state.images = vae.decode(latents, return_dict=False)[0] + block_state.images = components.image_processor.postprocess( + block_state.images, output_type=block_state.output_type + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/denoise.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..66783cc36953ad69cd13353fdd1648e3f1a5d221 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/denoise.py @@ -0,0 +1,509 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import Flux2Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Flux2LoopDenoiser(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("transformer", Flux2Transformer2DModel)] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoises the latents for Flux2. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `Flux2DenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("joint_attention_kwargs"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to denoise. Shape: (B, seq_len, C)", + ), + InputParam( + "image_latents", + type_hint=torch.Tensor, + description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)", + ), + InputParam( + "image_latent_ids", + type_hint=torch.Tensor, + description="Position IDs for image latents. Shape: (B, img_seq_len, 4)", + ), + InputParam( + "guidance", + required=True, + type_hint=torch.Tensor, + description="Guidance scale as a tensor", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Text embeddings from Mistral3", + ), + InputParam( + "txt_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for text tokens (T, H, W, L)", + ), + InputParam( + "latent_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for latent tokens (T, H, W, L)", + ), + ] + + @torch.no_grad() + def __call__( + self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + latents = block_state.latents + latent_model_input = latents.to(components.transformer.dtype) + img_ids = block_state.latent_ids + + image_latents = getattr(block_state, "image_latents", None) + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype) + image_latent_ids = block_state.image_latent_ids + img_ids = torch.cat([img_ids, image_latent_ids], dim=1) + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = components.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=block_state.guidance, + encoder_hidden_states=block_state.prompt_embeds, + txt_ids=block_state.txt_ids, + img_ids=img_ids, + joint_attention_kwargs=block_state.joint_attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1)] + block_state.noise_pred = noise_pred + + return components, block_state + + +# same as Flux2LoopDenoiser but guidance=None +class Flux2KleinLoopDenoiser(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("transformer", Flux2Transformer2DModel)] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoises the latents for Flux2. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `Flux2DenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("joint_attention_kwargs"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to denoise. Shape: (B, seq_len, C)", + ), + InputParam( + "image_latents", + type_hint=torch.Tensor, + description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)", + ), + InputParam( + "image_latent_ids", + type_hint=torch.Tensor, + description="Position IDs for image latents. Shape: (B, img_seq_len, 4)", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Text embeddings from Qwen3", + ), + InputParam( + "txt_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for text tokens (T, H, W, L)", + ), + InputParam( + "latent_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for latent tokens (T, H, W, L)", + ), + ] + + @torch.no_grad() + def __call__( + self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + latents = block_state.latents + latent_model_input = latents.to(components.transformer.dtype) + img_ids = block_state.latent_ids + + image_latents = getattr(block_state, "image_latents", None) + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype) + image_latent_ids = block_state.image_latent_ids + img_ids = torch.cat([img_ids, image_latent_ids], dim=1) + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = components.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=None, + encoder_hidden_states=block_state.prompt_embeds, + txt_ids=block_state.txt_ids, + img_ids=img_ids, + joint_attention_kwargs=block_state.joint_attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1)] + block_state.noise_pred = noise_pred + + return components, block_state + + +# support CFG for Flux2-Klein base model +class Flux2KleinBaseLoopDenoiser(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("transformer", Flux2Transformer2DModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ] + + @property + def expected_configs(self) -> list[ConfigSpec]: + return [ + ConfigSpec(name="is_distilled", default=False), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoises the latents for Flux2. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `Flux2DenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("joint_attention_kwargs"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to denoise. Shape: (B, seq_len, C)", + ), + InputParam( + "image_latents", + type_hint=torch.Tensor, + description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)", + ), + InputParam( + "image_latent_ids", + type_hint=torch.Tensor, + description="Position IDs for image latents. Shape: (B, img_seq_len, 4)", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Text embeddings from Qwen3", + ), + InputParam( + "negative_prompt_embeds", + required=False, + type_hint=torch.Tensor, + description="Negative text embeddings from Qwen3", + ), + InputParam( + "txt_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for text tokens (T, H, W, L)", + ), + InputParam( + "negative_txt_ids", + required=False, + type_hint=torch.Tensor, + description="4D position IDs for negative text tokens (T, H, W, L)", + ), + InputParam( + "latent_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for latent tokens (T, H, W, L)", + ), + ] + + @torch.no_grad() + def __call__( + self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + latents = block_state.latents + latent_model_input = latents.to(components.transformer.dtype) + img_ids = block_state.latent_ids + + image_latents = getattr(block_state, "image_latents", None) + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype) + image_latent_ids = block_state.image_latent_ids + img_ids = torch.cat([img_ids, image_latent_ids], dim=1) + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + guider_inputs = { + "encoder_hidden_states": ( + getattr(block_state, "prompt_embeds", None), + getattr(block_state, "negative_prompt_embeds", None), + ), + "txt_ids": ( + getattr(block_state, "txt_ids", None), + getattr(block_state, "negative_txt_ids", None), + ), + } + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs(guider_inputs) + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()} + + noise_pred = components.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=None, + img_ids=img_ids, + joint_attention_kwargs=block_state.joint_attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + guider_state_batch.noise_pred = noise_pred[:, : latents.size(1)] + components.guider.cleanup_models(components.transformer) + + # perform guidance + block_state.noise_pred = components.guider(guider_state)[0] + + return components, block_state + + +class Flux2LoopAfterDenoiser(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that updates the latents after denoising. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `Flux2DenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [] + + @property + def intermediate_inputs(self) -> list[str]: + return [InputParam("generator")] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred, + t, + block_state.latents, + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class Flux2DenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoises the latents over `timesteps`. " + "The specific steps within each iteration can be customized with `sub_blocks` attribute" + ) + + @property + def loop_expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ComponentSpec("transformer", Flux2Transformer2DModel), + ] + + @property + def loop_inputs(self) -> list[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process.", + ), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self.set_block_state(state, block_state) + return components, state + + +class Flux2DenoiseStep(Flux2DenoiseLoopWrapper): + block_classes = [Flux2LoopDenoiser, Flux2LoopAfterDenoiser] + block_names = ["denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoises the latents for Flux2. \n" + "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `Flux2LoopDenoiser`\n" + " - `Flux2LoopAfterDenoiser`\n" + "This block supports both text-to-image and image-conditioned generation." + ) + + +class Flux2KleinDenoiseStep(Flux2DenoiseLoopWrapper): + block_classes = [Flux2KleinLoopDenoiser, Flux2LoopAfterDenoiser] + block_names = ["denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoises the latents for Flux2. \n" + "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `Flux2KleinLoopDenoiser`\n" + " - `Flux2LoopAfterDenoiser`\n" + "This block supports both text-to-image and image-conditioned generation." + ) + + +class Flux2KleinBaseDenoiseStep(Flux2DenoiseLoopWrapper): + block_classes = [Flux2KleinBaseLoopDenoiser, Flux2LoopAfterDenoiser] + block_names = ["denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoises the latents for Flux2. \n" + "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `Flux2KleinBaseLoopDenoiser`\n" + " - `Flux2LoopAfterDenoiser`\n" + "This block supports both text-to-image and image-conditioned generation." + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/encoders.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..81d20a8f4c6532db2b04d1423e77c1a9f1b0347d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/encoders.py @@ -0,0 +1,608 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Qwen2TokenizerFast, Qwen3ForCausalLM + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import AutoencoderKLFlux2 +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def format_text_input(prompts: list[str], system_message: str = None): + """Format prompts for Mistral3 chat template.""" + cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts] + + return [ + [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + for prompt in cleaned_txt + ] + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class Flux2TextEncoderStep(ModularPipelineBlocks): + model_name = "flux2" + + # fmt: off + DEFAULT_SYSTEM_MESSAGE = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation." + # fmt: on + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings using Mistral3 to guide the image generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Mistral3ForConditionalGeneration), + ComponentSpec("tokenizer", AutoProcessor), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("prompt"), + InputParam("max_sequence_length", type_hint=int, default=512, required=False), + InputParam("text_encoder_out_layers", type_hint=tuple[int], default=(10, 20, 30), required=False), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Text embeddings from Mistral3 used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + prompt = block_state.prompt + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @staticmethod + def _get_mistral_3_prompt_embeds( + text_encoder: Mistral3ForConditionalGeneration, + tokenizer: AutoProcessor, + prompt: str | list[str], + dtype: torch.dtype | None = None, + device: torch.device | None = None, + max_sequence_length: int = 512, + # fmt: off + system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.", + # fmt: on + hidden_states_layers: tuple[int] = (10, 20, 30), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + messages_batch = format_text_input(prompts=prompt, system_message=system_message) + + inputs = tokenizer.apply_chat_template( + messages_batch, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.device = components._execution_device + + prompt = block_state.prompt + if prompt is None: + prompt = "" + prompt = [prompt] if isinstance(prompt, str) else prompt + + block_state.prompt_embeds = self._get_mistral_3_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, + device=block_state.device, + max_sequence_length=block_state.max_sequence_length, + system_message=self.DEFAULT_SYSTEM_MESSAGE, + hidden_states_layers=block_state.text_encoder_out_layers, + ) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2RemoteTextEncoderStep(ModularPipelineBlocks): + model_name = "flux2" + + REMOTE_URL = "https://remote-text-encoder-flux-2.huggingface.co/predict" + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings using a remote API endpoint" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("prompt"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Text embeddings from remote API used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + prompt = block_state.prompt + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + import io + + import requests + from huggingface_hub import get_token + + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.device = components._execution_device + + prompt = block_state.prompt + if prompt is None: + prompt = "" + prompt = [prompt] if isinstance(prompt, str) else prompt + + response = requests.post( + self.REMOTE_URL, + json={"prompt": prompt}, + headers={ + "Authorization": f"Bearer {get_token()}", + "Content-Type": "application/json", + }, + ) + response.raise_for_status() + + block_state.prompt_embeds = torch.load(io.BytesIO(response.content), weights_only=True) + block_state.prompt_embeds = block_state.prompt_embeds.to(block_state.device) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2KleinTextEncoderStep(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen3ForCausalLM), + ComponentSpec("tokenizer", Qwen2TokenizerFast), + ] + + @property + def expected_configs(self) -> list[ConfigSpec]: + return [ + ConfigSpec(name="is_distilled", default=True), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("prompt"), + InputParam("max_sequence_length", type_hint=int, default=512, required=False), + InputParam("text_encoder_out_layers", type_hint=tuple[int], default=(9, 18, 27), required=False), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Text embeddings from qwen3 used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + prompt = block_state.prompt + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds + def _get_qwen3_prompt_embeds( + text_encoder: Qwen3ForCausalLM, + tokenizer: Qwen2TokenizerFast, + prompt: str | list[str], + dtype: torch.dtype | None = None, + device: torch.device | None = None, + max_sequence_length: int = 512, + hidden_states_layers: list[int] = (9, 18, 27), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + all_input_ids = [] + all_attention_masks = [] + + for single_prompt in prompt: + messages = [{"role": "user", "content": single_prompt}] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + all_input_ids.append(inputs["input_ids"]) + all_attention_masks.append(inputs["attention_mask"]) + + input_ids = torch.cat(all_input_ids, dim=0).to(device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @torch.no_grad() + def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + device = components._execution_device + + prompt = block_state.prompt + if prompt is None: + prompt = "" + prompt = [prompt] if isinstance(prompt, str) else prompt + + block_state.prompt_embeds = self._get_qwen3_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=block_state.max_sequence_length, + hidden_states_layers=block_state.text_encoder_out_layers, + ) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2KleinBaseTextEncoderStep(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen3ForCausalLM), + ComponentSpec("tokenizer", Qwen2TokenizerFast), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ] + + @property + def expected_configs(self) -> list[ConfigSpec]: + return [ + ConfigSpec(name="is_distilled", default=False), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("prompt"), + InputParam("max_sequence_length", type_hint=int, default=512, required=False), + InputParam("text_encoder_out_layers", type_hint=tuple[int], default=(9, 18, 27), required=False), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Text embeddings from qwen3 used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Negative text embeddings from qwen3 used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + prompt = block_state.prompt + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds + def _get_qwen3_prompt_embeds( + text_encoder: Qwen3ForCausalLM, + tokenizer: Qwen2TokenizerFast, + prompt: str | list[str], + dtype: torch.dtype | None = None, + device: torch.device | None = None, + max_sequence_length: int = 512, + hidden_states_layers: list[int] = (9, 18, 27), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + all_input_ids = [] + all_attention_masks = [] + + for single_prompt in prompt: + messages = [{"role": "user", "content": single_prompt}] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + all_input_ids.append(inputs["input_ids"]) + all_attention_masks.append(inputs["attention_mask"]) + + input_ids = torch.cat(all_input_ids, dim=0).to(device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @torch.no_grad() + def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + device = components._execution_device + + prompt = block_state.prompt + if prompt is None: + prompt = "" + prompt = [prompt] if isinstance(prompt, str) else prompt + + block_state.prompt_embeds = self._get_qwen3_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=block_state.max_sequence_length, + hidden_states_layers=block_state.text_encoder_out_layers, + ) + + if components.requires_unconditional_embeds: + negative_prompt = [""] * len(prompt) + block_state.negative_prompt_embeds = self._get_qwen3_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=negative_prompt, + device=device, + max_sequence_length=block_state.max_sequence_length, + hidden_states_layers=block_state.text_encoder_out_layers, + ) + else: + block_state.negative_prompt_embeds = None + + self.set_block_state(state, block_state) + return components, state + + +class Flux2VaeEncoderStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "VAE Encoder step that encodes preprocessed images into latent representations for Flux2." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("vae", AutoencoderKLFlux2)] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("condition_images", type_hint=list[torch.Tensor]), + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "image_latents", + type_hint=list[torch.Tensor], + description="List of latent representations for each reference image", + ), + ] + + @staticmethod + def _patchify_latents(latents): + """Convert latents to patchified format for Flux2.""" + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + def _encode_vae_image(self, vae: AutoencoderKLFlux2, image: torch.Tensor, generator: torch.Generator): + """Encode a single image using Flux2 VAE with batch norm normalization.""" + if image.ndim != 4: + raise ValueError(f"Expected image dims 4, got {image.ndim}.") + + image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = self._patchify_latents(image_latents) + + latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) + latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps) + latents_bn_std = latents_bn_std.to(image_latents.device, image_latents.dtype) + image_latents = (image_latents - latents_bn_mean) / latents_bn_std + + return image_latents + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + condition_images = block_state.condition_images + + if condition_images is None: + return components, state + + device = components._execution_device + dtype = components.vae.dtype + + image_latents = [] + for image in condition_images: + image = image.to(device=device, dtype=dtype) + latent = self._encode_vae_image( + vae=components.vae, + image=image, + generator=block_state.generator, + ) + image_latents.append(latent) + + block_state.image_latents = image_latents + + self.set_block_state(state, block_state) + return components, state diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/inputs.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/inputs.py new file mode 100644 index 0000000000000000000000000000000000000000..7c6881f70a46c6948d04e9c9d980ee194da683c5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/inputs.py @@ -0,0 +1,242 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ...configuration_utils import FrozenDict +from ...pipelines.flux2.image_processor import Flux2ImageProcessor +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import Flux2ModularPipeline + + +logger = logging.get_logger(__name__) + + +class Flux2TextInputStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return ( + "This step:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + InputParam( + "prompt_embeds", + required=True, + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `prompt_embeds`)", + ), + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Text embeddings used to guide the image generation", + ), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2KleinBaseTextInputStep(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def description(self) -> str: + return ( + "This step:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + InputParam( + "prompt_embeds", + required=True, + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "negative_prompt_embeds", + required=False, + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Pre-generated negative text embeddings. Can be generated from text_encoder step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `prompt_embeds`)", + ), + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Text embeddings used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Negative text embeddings used to guide the image generation", + ), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2ProcessImagesInputStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "Image preprocess step for Flux2. Validates and preprocesses reference images." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + Flux2ImageProcessor, + config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("image"), + InputParam("height"), + InputParam("width"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam(name="condition_images", type_hint=list[torch.Tensor])] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + images = block_state.image + + if images is None: + block_state.condition_images = None + self.set_block_state(state, block_state) + return components, state + + if not isinstance(images, list): + images = [images] + + condition_images = [] + for img in images: + components.image_processor.check_image_input(img) + + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + img = components.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + + multiple_of = components.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + condition_img = components.image_processor.preprocess( + img, height=image_height, width=image_width, resize_mode="crop" + ) + condition_images.append(condition_img) + + if block_state.height is None: + block_state.height = image_height + if block_state.width is None: + block_state.width = image_width + + block_state.condition_images = condition_images + + self.set_block_state(state, block_state) + return components, state diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py new file mode 100644 index 0000000000000000000000000000000000000000..b1033a7dff9e868dabbd078822c814be0d9fc546 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py @@ -0,0 +1,356 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict, OutputParam +from .before_denoise import ( + Flux2PrepareGuidanceStep, + Flux2PrepareImageLatentsStep, + Flux2PrepareLatentsStep, + Flux2RoPEInputsStep, + Flux2SetTimestepsStep, +) +from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep +from .denoise import Flux2DenoiseStep +from .encoders import ( + Flux2TextEncoderStep, + Flux2VaeEncoderStep, +) +from .inputs import ( + Flux2ProcessImagesInputStep, + Flux2TextInputStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# auto_docstring +class Flux2VaeEncoderSequentialStep(SequentialPipelineBlocks): + """ + VAE encoder step that preprocesses, encodes, and prepares image latents for Flux2 conditioning. + + Components: + image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`) + + Inputs: + image (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + condition_images (`list`): + TODO: Add description. + image_latents (`list`): + List of latent representations for each reference image + """ + + model_name = "flux2" + + block_classes = [Flux2ProcessImagesInputStep(), Flux2VaeEncoderStep()] + block_names = ["preprocess", "encode"] + + @property + def description(self) -> str: + return "VAE encoder step that preprocesses, encodes, and prepares image latents for Flux2 conditioning." + + +# auto_docstring +class Flux2AutoVaeEncoderStep(AutoPipelineBlocks): + """ + VAE encoder step that encodes the image inputs into their latent representations. + This is an auto pipeline block that works for image conditioning tasks. + - `Flux2VaeEncoderSequentialStep` is used when `image` is provided. + - If `image` is not provided, step will be skipped. + + Components: + image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`) + + Inputs: + image (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + condition_images (`list`): + TODO: Add description. + image_latents (`list`): + List of latent representations for each reference image + """ + + block_classes = [Flux2VaeEncoderSequentialStep] + block_names = ["img_conditioning"] + block_trigger_inputs = ["image"] + + @property + def description(self): + return ( + "VAE encoder step that encodes the image inputs into their latent representations.\n" + "This is an auto pipeline block that works for image conditioning tasks.\n" + " - `Flux2VaeEncoderSequentialStep` is used when `image` is provided.\n" + " - If `image` is not provided, step will be skipped." + ) + + +Flux2CoreDenoiseBlocks = InsertableDict( + [ + ("input", Flux2TextInputStep()), + ("prepare_latents", Flux2PrepareLatentsStep()), + ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_guidance", Flux2PrepareGuidanceStep()), + ("prepare_rope_inputs", Flux2RoPEInputsStep()), + ("denoise", Flux2DenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), + ] +) + + +# auto_docstring +class Flux2CoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoise step that performs the denoising process for Flux2-dev. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 4.0): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + image_latents (`Tensor`, *optional*): + Packed image latents for conditioning. Shape: (B, img_seq_len, C) + image_latent_ids (`Tensor`, *optional*): + Position IDs for image latents. Shape: (B, img_seq_len, 4) + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "flux2" + + block_classes = Flux2CoreDenoiseBlocks.values() + block_names = Flux2CoreDenoiseBlocks.keys() + + @property + def description(self): + return "Core denoise step that performs the denoising process for Flux2-dev." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +Flux2ImageConditionedCoreDenoiseBlocks = InsertableDict( + [ + ("input", Flux2TextInputStep()), + ("prepare_image_latents", Flux2PrepareImageLatentsStep()), + ("prepare_latents", Flux2PrepareLatentsStep()), + ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_guidance", Flux2PrepareGuidanceStep()), + ("prepare_rope_inputs", Flux2RoPEInputsStep()), + ("denoise", Flux2DenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), + ] +) + + +# auto_docstring +class Flux2ImageConditionedCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoise step that performs the denoising process for Flux2-dev with image conditioning. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + image_latents (`list`, *optional*): + TODO: Add description. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 4.0): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "flux2" + + block_classes = Flux2ImageConditionedCoreDenoiseBlocks.values() + block_names = Flux2ImageConditionedCoreDenoiseBlocks.keys() + + @property + def description(self): + return "Core denoise step that performs the denoising process for Flux2-dev with image conditioning." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +class Flux2AutoCoreDenoiseStep(AutoPipelineBlocks): + model_name = "flux2" + + block_classes = [Flux2ImageConditionedCoreDenoiseStep, Flux2CoreDenoiseStep] + block_names = ["image_conditioned", "text2image"] + block_trigger_inputs = ["image_latents", None] + + @property + def description(self): + return ( + "Auto core denoise step that performs the denoising process for Flux2-dev." + "This is an auto pipeline block that works for text-to-image and image-conditioned generation." + " - `Flux2CoreDenoiseStep` is used for text-to-image generation.\n" + " - `Flux2ImageConditionedCoreDenoiseStep` is used for image-conditioned generation.\n" + ) + + +AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", Flux2TextEncoderStep()), + ("vae_encoder", Flux2AutoVaeEncoderStep()), + ("denoise", Flux2AutoCoreDenoiseStep()), + ("decode", Flux2DecodeStep()), + ] +) + + +# auto_docstring +class Flux2AutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for text-to-image and image-conditioned generation using Flux2. + + Supported workflows: + - `text2image`: requires `prompt` + - `image_conditioned`: requires `image`, `prompt` + + Components: + text_encoder (`Mistral3ForConditionalGeneration`) tokenizer (`AutoProcessor`) image_processor + (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`) scheduler (`FlowMatchEulerDiscreteScheduler`) transformer + (`Flux2Transformer2DModel`) + + Inputs: + prompt (`None`, *optional*): + TODO: Add description. + max_sequence_length (`int`, *optional*, defaults to 512): + TODO: Add description. + text_encoder_out_layers (`tuple`, *optional*, defaults to (10, 20, 30)): + TODO: Add description. + image (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + image_latents (`list`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`): + TODO: Add description. + num_inference_steps (`None`): + TODO: Add description. + timesteps (`None`): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + guidance_scale (`None`, *optional*, defaults to 4.0): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + image_latent_ids (`Tensor`, *optional*): + Position IDs for image latents. Shape: (B, img_seq_len, 4) + output_type (`None`, *optional*, defaults to pil): + TODO: Add description. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "flux2" + + block_classes = AUTO_BLOCKS.values() + block_names = AUTO_BLOCKS.keys() + _workflow_map = { + "text2image": {"prompt": True}, + "image_conditioned": {"image": True, "prompt": True}, + } + + @property + def description(self): + return "Auto Modular pipeline for text-to-image and image-conditioned generation using Flux2." + + @property + def outputs(self): + return [ + OutputParam.template("images"), + ] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py new file mode 100644 index 0000000000000000000000000000000000000000..5dbae43a5a7fd6c4df8ebb6314ed8c20d7c3372f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py @@ -0,0 +1,400 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict, OutputParam +from .before_denoise import ( + Flux2PrepareImageLatentsStep, + Flux2PrepareLatentsStep, + Flux2RoPEInputsStep, + Flux2SetTimestepsStep, +) +from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep +from .denoise import Flux2KleinDenoiseStep +from .encoders import ( + Flux2KleinTextEncoderStep, + Flux2VaeEncoderStep, +) +from .inputs import ( + Flux2ProcessImagesInputStep, + Flux2TextInputStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +################ +# VAE encoder +################ + + +# auto_docstring +class Flux2KleinVaeEncoderSequentialStep(SequentialPipelineBlocks): + """ + VAE encoder step that preprocesses and encodes the image inputs into their latent representations. + + Components: + image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`) + + Inputs: + image (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + condition_images (`list`): + TODO: Add description. + image_latents (`list`): + List of latent representations for each reference image + """ + + model_name = "flux2-klein" + + block_classes = [Flux2ProcessImagesInputStep(), Flux2VaeEncoderStep()] + block_names = ["preprocess", "encode"] + + @property + def description(self) -> str: + return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations." + + +# auto_docstring +class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks): + """ + VAE encoder step that encodes the image inputs into their latent representations. + This is an auto pipeline block that works for image conditioning tasks. + - `Flux2KleinVaeEncoderSequentialStep` is used when `image` is provided. + - If `image` is not provided, step will be skipped. + + Components: + image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`) + + Inputs: + image (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + condition_images (`list`): + TODO: Add description. + image_latents (`list`): + List of latent representations for each reference image + """ + + model_name = "flux2-klein" + + block_classes = [Flux2KleinVaeEncoderSequentialStep] + block_names = ["img_conditioning"] + block_trigger_inputs = ["image"] + + @property + def description(self): + return ( + "VAE encoder step that encodes the image inputs into their latent representations.\n" + "This is an auto pipeline block that works for image conditioning tasks.\n" + " - `Flux2KleinVaeEncoderSequentialStep` is used when `image` is provided.\n" + " - If `image` is not provided, step will be skipped." + ) + + +### +### Core denoise +### + +Flux2KleinCoreDenoiseBlocks = InsertableDict( + [ + ("input", Flux2TextInputStep()), + ("prepare_latents", Flux2PrepareLatentsStep()), + ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_rope_inputs", Flux2RoPEInputsStep()), + ("denoise", Flux2KleinDenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), + ] +) + + +# auto_docstring +class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoise step that performs the denoising process for Flux2-Klein (distilled model), for text-to-image + generation. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + image_latents (`Tensor`, *optional*): + Packed image latents for conditioning. Shape: (B, img_seq_len, C) + image_latent_ids (`Tensor`, *optional*): + Position IDs for image latents. Shape: (B, img_seq_len, 4) + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "flux2-klein" + + block_classes = Flux2KleinCoreDenoiseBlocks.values() + block_names = Flux2KleinCoreDenoiseBlocks.keys() + + @property + def description(self): + return "Core denoise step that performs the denoising process for Flux2-Klein (distilled model), for text-to-image generation." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +Flux2KleinImageConditionedCoreDenoiseBlocks = InsertableDict( + [ + ("input", Flux2TextInputStep()), + ("prepare_image_latents", Flux2PrepareImageLatentsStep()), + ("prepare_latents", Flux2PrepareLatentsStep()), + ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_rope_inputs", Flux2RoPEInputsStep()), + ("denoise", Flux2KleinDenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), + ] +) + + +# auto_docstring +class Flux2KleinImageConditionedCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoise step that performs the denoising process for Flux2-Klein (distilled model) with image conditioning. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + image_latents (`list`, *optional*): + TODO: Add description. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "flux2-klein" + + block_classes = Flux2KleinImageConditionedCoreDenoiseBlocks.values() + block_names = Flux2KleinImageConditionedCoreDenoiseBlocks.keys() + + @property + def description(self): + return "Core denoise step that performs the denoising process for Flux2-Klein (distilled model) with image conditioning." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# auto_docstring +class Flux2KleinAutoCoreDenoiseStep(AutoPipelineBlocks): + """ + Auto core denoise step that performs the denoising process for Flux2-Klein. + This is an auto pipeline block that works for text-to-image and image-conditioned generation. + - `Flux2KleinCoreDenoiseStep` is used for text-to-image generation. + - `Flux2KleinImageConditionedCoreDenoiseStep` is used for image-conditioned generation. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + image_latents (`list`, *optional*): + TODO: Add description. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`): + TODO: Add description. + timesteps (`None`): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + image_latent_ids (`Tensor`, *optional*): + Position IDs for image latents. Shape: (B, img_seq_len, 4) + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "flux2-klein" + block_classes = [Flux2KleinImageConditionedCoreDenoiseStep, Flux2KleinCoreDenoiseStep] + block_names = ["image_conditioned", "text2image"] + block_trigger_inputs = ["image_latents", None] + + @property + def description(self): + return ( + "Auto core denoise step that performs the denoising process for Flux2-Klein.\n" + "This is an auto pipeline block that works for text-to-image and image-conditioned generation.\n" + " - `Flux2KleinCoreDenoiseStep` is used for text-to-image generation.\n" + " - `Flux2KleinImageConditionedCoreDenoiseStep` is used for image-conditioned generation.\n" + ) + + +### +### Auto blocks +### + + +# auto_docstring +class Flux2KleinAutoBlocks(SequentialPipelineBlocks): + """ + Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein. + + Supported workflows: + - `text2image`: requires `prompt` + - `image_conditioned`: requires `image`, `prompt` + + Components: + text_encoder (`Qwen3ForCausalLM`) tokenizer (`Qwen2TokenizerFast`) image_processor (`Flux2ImageProcessor`) + vae (`AutoencoderKLFlux2`) scheduler (`FlowMatchEulerDiscreteScheduler`) transformer + (`Flux2Transformer2DModel`) + + Configs: + is_distilled (default: True) + + Inputs: + prompt (`None`, *optional*): + TODO: Add description. + max_sequence_length (`int`, *optional*, defaults to 512): + TODO: Add description. + text_encoder_out_layers (`tuple`, *optional*, defaults to (9, 18, 27)): + TODO: Add description. + image (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + image_latents (`list`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`): + TODO: Add description. + num_inference_steps (`None`): + TODO: Add description. + timesteps (`None`): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + image_latent_ids (`Tensor`, *optional*): + Position IDs for image latents. Shape: (B, img_seq_len, 4) + output_type (`None`, *optional*, defaults to pil): + TODO: Add description. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "flux2-klein" + block_classes = [ + Flux2KleinTextEncoderStep(), + Flux2KleinAutoVaeEncoderStep(), + Flux2KleinAutoCoreDenoiseStep(), + Flux2DecodeStep(), + ] + block_names = ["text_encoder", "vae_encoder", "denoise", "decode"] + _workflow_map = { + "text2image": {"prompt": True}, + "image_conditioned": {"image": True, "prompt": True}, + } + + @property + def description(self): + return "Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein." + + @property + def outputs(self): + return [ + OutputParam.template("images"), + ] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein_base.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein_base.py new file mode 100644 index 0000000000000000000000000000000000000000..42e025c622b4e04602f5ab7dd13de46bdf555623 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein_base.py @@ -0,0 +1,413 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict, OutputParam +from .before_denoise import ( + Flux2KleinBaseRoPEInputsStep, + Flux2PrepareImageLatentsStep, + Flux2PrepareLatentsStep, + Flux2SetTimestepsStep, +) +from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep +from .denoise import Flux2KleinBaseDenoiseStep +from .encoders import ( + Flux2KleinBaseTextEncoderStep, + Flux2VaeEncoderStep, +) +from .inputs import ( + Flux2KleinBaseTextInputStep, + Flux2ProcessImagesInputStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +################ +# VAE encoder +################ + + +# auto_docstring +class Flux2KleinBaseVaeEncoderSequentialStep(SequentialPipelineBlocks): + """ + VAE encoder step that preprocesses and encodes the image inputs into their latent representations. + + Components: + image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`) + + Inputs: + image (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + condition_images (`list`): + TODO: Add description. + image_latents (`list`): + List of latent representations for each reference image + """ + + model_name = "flux2" + + block_classes = [Flux2ProcessImagesInputStep(), Flux2VaeEncoderStep()] + block_names = ["preprocess", "encode"] + + @property + def description(self) -> str: + return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations." + + +# auto_docstring +class Flux2KleinBaseAutoVaeEncoderStep(AutoPipelineBlocks): + """ + VAE encoder step that encodes the image inputs into their latent representations. + This is an auto pipeline block that works for image conditioning tasks. + - `Flux2KleinBaseVaeEncoderSequentialStep` is used when `image` is provided. + - If `image` is not provided, step will be skipped. + + Components: + image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`) + + Inputs: + image (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + condition_images (`list`): + TODO: Add description. + image_latents (`list`): + List of latent representations for each reference image + """ + + block_classes = [Flux2KleinBaseVaeEncoderSequentialStep] + block_names = ["img_conditioning"] + block_trigger_inputs = ["image"] + + @property + def description(self): + return ( + "VAE encoder step that encodes the image inputs into their latent representations.\n" + "This is an auto pipeline block that works for image conditioning tasks.\n" + " - `Flux2KleinBaseVaeEncoderSequentialStep` is used when `image` is provided.\n" + " - If `image` is not provided, step will be skipped." + ) + + +### +### Core denoise +### + +Flux2KleinBaseCoreDenoiseBlocks = InsertableDict( + [ + ("input", Flux2KleinBaseTextInputStep()), + ("prepare_latents", Flux2PrepareLatentsStep()), + ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_rope_inputs", Flux2KleinBaseRoPEInputsStep()), + ("denoise", Flux2KleinBaseDenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), + ] +) + + +# auto_docstring +class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoise step that performs the denoising process for Flux2-Klein (base model), for text-to-image generation. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`) guider + (`ClassifierFreeGuidance`) + + Configs: + is_distilled (default: False) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + image_latents (`Tensor`, *optional*): + Packed image latents for conditioning. Shape: (B, img_seq_len, C) + image_latent_ids (`Tensor`, *optional*): + Position IDs for image latents. Shape: (B, img_seq_len, 4) + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "flux2-klein" + block_classes = Flux2KleinBaseCoreDenoiseBlocks.values() + block_names = Flux2KleinBaseCoreDenoiseBlocks.keys() + + @property + def description(self): + return "Core denoise step that performs the denoising process for Flux2-Klein (base model), for text-to-image generation." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +Flux2KleinBaseImageConditionedCoreDenoiseBlocks = InsertableDict( + [ + ("input", Flux2KleinBaseTextInputStep()), + ("prepare_latents", Flux2PrepareLatentsStep()), + ("prepare_image_latents", Flux2PrepareImageLatentsStep()), + ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_rope_inputs", Flux2KleinBaseRoPEInputsStep()), + ("denoise", Flux2KleinBaseDenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), + ] +) + + +# auto_docstring +class Flux2KleinBaseImageConditionedCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoise step that performs the denoising process for Flux2-Klein (base model) with image conditioning. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`) guider + (`ClassifierFreeGuidance`) + + Configs: + is_distilled (default: False) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + image_latents (`list`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "flux2-klein" + block_classes = Flux2KleinBaseImageConditionedCoreDenoiseBlocks.values() + block_names = Flux2KleinBaseImageConditionedCoreDenoiseBlocks.keys() + + @property + def description(self): + return "Core denoise step that performs the denoising process for Flux2-Klein (base model) with image conditioning." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# auto_docstring +class Flux2KleinBaseAutoCoreDenoiseStep(AutoPipelineBlocks): + """ + Auto core denoise step that performs the denoising process for Flux2-Klein (base model). + This is an auto pipeline block that works for text-to-image and image-conditioned generation. + - `Flux2KleinBaseCoreDenoiseStep` is used for text-to-image generation. + - `Flux2KleinBaseImageConditionedCoreDenoiseStep` is used for image-conditioned generation. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`) guider + (`ClassifierFreeGuidance`) + + Configs: + is_distilled (default: False) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + image_latents (`list`, *optional*): + TODO: Add description. + num_inference_steps (`None`): + TODO: Add description. + timesteps (`None`): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + image_latent_ids (`Tensor`, *optional*): + Position IDs for image latents. Shape: (B, img_seq_len, 4) + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "flux2-klein" + block_classes = [Flux2KleinBaseImageConditionedCoreDenoiseStep, Flux2KleinBaseCoreDenoiseStep] + block_names = ["image_conditioned", "text2image"] + block_trigger_inputs = ["image_latents", None] + + @property + def description(self): + return ( + "Auto core denoise step that performs the denoising process for Flux2-Klein (base model).\n" + "This is an auto pipeline block that works for text-to-image and image-conditioned generation.\n" + " - `Flux2KleinBaseCoreDenoiseStep` is used for text-to-image generation.\n" + " - `Flux2KleinBaseImageConditionedCoreDenoiseStep` is used for image-conditioned generation.\n" + ) + + +### +### Auto blocks +### + + +# auto_docstring +class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks): + """ + Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein (base model). + + Supported workflows: + - `text2image`: requires `prompt` + - `image_conditioned`: requires `image`, `prompt` + + Components: + text_encoder (`Qwen3ForCausalLM`) tokenizer (`Qwen2TokenizerFast`) guider (`ClassifierFreeGuidance`) + image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`) scheduler + (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`) + + Configs: + is_distilled (default: False) + + Inputs: + prompt (`None`, *optional*): + TODO: Add description. + max_sequence_length (`int`, *optional*, defaults to 512): + TODO: Add description. + text_encoder_out_layers (`tuple`, *optional*, defaults to (9, 18, 27)): + TODO: Add description. + image (`None`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + latents (`Tensor | NoneType`): + TODO: Add description. + image_latents (`list`, *optional*): + TODO: Add description. + num_inference_steps (`None`): + TODO: Add description. + timesteps (`None`): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + joint_attention_kwargs (`None`, *optional*): + TODO: Add description. + image_latent_ids (`Tensor`, *optional*): + Position IDs for image latents. Shape: (B, img_seq_len, 4) + output_type (`None`, *optional*, defaults to pil): + TODO: Add description. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "flux2-klein" + block_classes = [ + Flux2KleinBaseTextEncoderStep(), + Flux2KleinBaseAutoVaeEncoderStep(), + Flux2KleinBaseAutoCoreDenoiseStep(), + Flux2DecodeStep(), + ] + block_names = ["text_encoder", "vae_encoder", "denoise", "decode"] + _workflow_map = { + "text2image": {"prompt": True}, + "image_conditioned": {"image": True, "prompt": True}, + } + + @property + def description(self): + return "Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein (base model)." + + @property + def outputs(self): + return [ + OutputParam.template("images"), + ] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/modular_pipeline.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/modular_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..31ba5aec7cfbeedf1d842e918650da0243719e25 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/flux2/modular_pipeline.py @@ -0,0 +1,99 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...loaders import Flux2LoraLoaderMixin +from ...utils import logging +from ..modular_pipeline import ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin): + """ + A ModularPipeline for Flux2. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "Flux2AutoBlocks" + + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + return 128 + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if getattr(self, "vae", None) is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 32 + if getattr(self, "transformer", None): + num_channels_latents = self.transformer.config.in_channels // 4 + return num_channels_latents + + +class Flux2KleinModularPipeline(Flux2ModularPipeline): + """ + A ModularPipeline for Flux2-Klein (distilled model). + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "Flux2KleinAutoBlocks" + + @property + def requires_unconditional_embeds(self): + if hasattr(self.config, "is_distilled") and self.config.is_distilled: + return False + + requires_unconditional_embeds = False + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 + + return requires_unconditional_embeds + + +class Flux2KleinBaseModularPipeline(Flux2ModularPipeline): + """ + A ModularPipeline for Flux2-Klein (base model). + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "Flux2KleinBaseAutoBlocks" + + @property + def requires_unconditional_embeds(self): + if hasattr(self.config, "is_distilled") and self.config.is_distilled: + return False + + requires_unconditional_embeds = False + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 + + return requires_unconditional_embeds diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2e6af4495b376219a6bb6bcfaeb3b70bace9d89d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__init__.py @@ -0,0 +1,63 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_blocks_qwenimage"] = ["QwenImageAutoBlocks"] + _import_structure["modular_blocks_qwenimage_edit"] = ["QwenImageEditAutoBlocks"] + _import_structure["modular_blocks_qwenimage_edit_plus"] = ["QwenImageEditPlusAutoBlocks"] + _import_structure["modular_blocks_qwenimage_layered"] = ["QwenImageLayeredAutoBlocks"] + _import_structure["modular_pipeline"] = [ + "QwenImageEditModularPipeline", + "QwenImageEditPlusModularPipeline", + "QwenImageLayeredModularPipeline", + "QwenImageModularPipeline", + ] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_blocks_qwenimage import QwenImageAutoBlocks + from .modular_blocks_qwenimage_edit import QwenImageEditAutoBlocks + from .modular_blocks_qwenimage_edit_plus import QwenImageEditPlusAutoBlocks + from .modular_blocks_qwenimage_layered import QwenImageLayeredAutoBlocks + from .modular_pipeline import ( + QwenImageEditModularPipeline, + QwenImageEditPlusModularPipeline, + QwenImageLayeredModularPipeline, + QwenImageModularPipeline, + ) +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2cd806f489598e5a3c5c7bd2397eb7200fbb7dc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/before_denoise.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/before_denoise.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..067254882df7b15bec474541845461e8bca77c73 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/before_denoise.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/decoders.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/decoders.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b30bb5b16cce4f5a2e1daab1e7ce31745627822 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/decoders.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/denoise.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/denoise.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bbc272b90ed23b11046400d521657d75c330974 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/denoise.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/encoders.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/encoders.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2a65e6d08c0868bc34c555e5671b355786c6efe Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/encoders.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/inputs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/inputs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..487f30cc07efe3345f2feb6b2daa2329935abb8b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/inputs.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/modular_blocks_qwenimage.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/modular_blocks_qwenimage.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78f39154f485e36853cfde325c022aee7c9a9f50 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/modular_blocks_qwenimage.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/modular_blocks_qwenimage_edit.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/modular_blocks_qwenimage_edit.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4598f20699b8f96e7102a873c73b46e49d3eae63 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/modular_blocks_qwenimage_edit.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/modular_blocks_qwenimage_edit_plus.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/modular_blocks_qwenimage_edit_plus.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f7185b5560f1e4b6f8c0d2227c2099323a8c5fe Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/modular_blocks_qwenimage_edit_plus.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/modular_blocks_qwenimage_layered.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/modular_blocks_qwenimage_layered.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d9f2d7927398b9c2b9ddb7e41b1af355e3402d5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/modular_blocks_qwenimage_layered.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/modular_pipeline.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/modular_pipeline.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d7385e2656605de99c123cb93b1c5345e159127 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/modular_pipeline.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/prompt_templates.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/prompt_templates.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..079b360fcc82b2ea882ac480dee4aea1c61c673f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/__pycache__/prompt_templates.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/before_denoise.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/before_denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..51b5c6ac8c3db39fa8b1cf8c333a803fb9bba2a2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/before_denoise.py @@ -0,0 +1,1330 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import numpy as np +import torch + +from ...models import QwenImageControlNetModel, QwenImageMultiControlNetModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils.torch_utils import randn_tensor, unwrap_module +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# modified from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps +def get_timesteps(scheduler, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = scheduler.timesteps[t_start * scheduler.order :] + if hasattr(scheduler, "set_begin_index"): + scheduler.set_begin_index(t_start * scheduler.order) + + return timesteps, num_inference_steps - t_start + + +# ==================== +# 1. PREPARE LATENTS +# ==================== + + +# auto_docstring +class QwenImagePrepareLatentsStep(ModularPipelineBlocks): + """ + Prepare initial random noise for the generation process + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. + + Outputs: + height (`int`): + if not set, updated to default value + width (`int`): + if not set, updated to default value + latents (`Tensor`): + The initial latents to use for the denoising process + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Prepare initial random noise for the generation process" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("latents"), + InputParam.template("height"), + InputParam.template("width"), + InputParam.template("num_images_per_prompt"), + InputParam.template("generator"), + InputParam.template("batch_size"), + InputParam.template("dtype"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam(name="height", type_hint=int, description="if not set, updated to default value"), + OutputParam(name="width", type_hint=int, description="if not set, updated to default value"), + OutputParam( + name="latents", + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process", + ), + ] + + @staticmethod + def check_inputs(height, width, vae_scale_factor): + if height is not None and height % (vae_scale_factor * 2) != 0: + raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}") + + if width is not None and width % (vae_scale_factor * 2) != 0: + raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + self.check_inputs( + height=block_state.height, + width=block_state.width, + vae_scale_factor=components.vae_scale_factor, + ) + + device = components._execution_device + batch_size = block_state.batch_size * block_state.num_images_per_prompt + + # we can update the height and width here since it's used to generate the initial + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + latent_height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2)) + latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2)) + + shape = (batch_size, components.num_channels_latents, 1, latent_height, latent_width) + if isinstance(block_state.generator, list) and len(block_state.generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if block_state.latents is None: + block_state.latents = randn_tensor( + shape, generator=block_state.generator, device=device, dtype=block_state.dtype + ) + block_state.latents = components.pachifier.pack_latents(block_state.latents) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks): + """ + Prepare initial random noise (B, layers+1, C, H, W) for the generation process + + Components: + pachifier (`QwenImageLayeredPachifier`) + + Inputs: + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. + + Outputs: + height (`int`): + if not set, updated to default value + width (`int`): + if not set, updated to default value + latents (`Tensor`): + The initial latents to use for the denoising process + """ + + model_name = "qwenimage-layered" + + @property + def description(self) -> str: + return "Prepare initial random noise (B, layers+1, C, H, W) for the generation process" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config"), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("latents"), + InputParam.template("height"), + InputParam.template("width"), + InputParam.template("layers"), + InputParam.template("num_images_per_prompt"), + InputParam.template("generator"), + InputParam.template("batch_size"), + InputParam.template("dtype"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam(name="height", type_hint=int, description="if not set, updated to default value"), + OutputParam(name="width", type_hint=int, description="if not set, updated to default value"), + OutputParam( + name="latents", + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process", + ), + ] + + @staticmethod + def check_inputs(height, width, vae_scale_factor): + if height is not None and height % (vae_scale_factor * 2) != 0: + raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}") + + if width is not None and width % (vae_scale_factor * 2) != 0: + raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + self.check_inputs( + height=block_state.height, + width=block_state.width, + vae_scale_factor=components.vae_scale_factor, + ) + + device = components._execution_device + batch_size = block_state.batch_size * block_state.num_images_per_prompt + + # we can update the height and width here since it's used to generate the initial + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + latent_height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2)) + latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2)) + + shape = (batch_size, block_state.layers + 1, components.num_channels_latents, latent_height, latent_width) + if isinstance(block_state.generator, list) and len(block_state.generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if block_state.latents is None: + block_state.latents = randn_tensor( + shape, generator=block_state.generator, device=device, dtype=block_state.dtype + ) + block_state.latents = components.pachifier.pack_latents(block_state.latents) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks): + """ + Step that adds noise to image latents for image-to-image/inpainting. Should be run after set_timesteps, + prepare_latents. Both noise and image latents should alreadybe patchified. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + latents (`Tensor`): + The initial random noised, can be generated in prepare latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (Can be + generated from vae encoder and updated in input step.) + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + + Outputs: + initial_noise (`Tensor`): + The initial random noised used for inpainting denoising. + latents (`Tensor`): + The scaled noisy latents to use for inpainting/image-to-image denoising. + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Step that adds noise to image latents for image-to-image/inpainting. Should be run after set_timesteps, prepare_latents. Both noise and image latents should alreadybe patchified." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The initial random noised, can be generated in prepare latent step.", + ), + InputParam.template("image_latents", note="Can be generated from vae encoder and updated in input step."), + InputParam( + name="timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="initial_noise", + type_hint=torch.Tensor, + description="The initial random noised used for inpainting denoising.", + ), + OutputParam( + name="latents", + type_hint=torch.Tensor, + description="The scaled noisy latents to use for inpainting/image-to-image denoising.", + ), + ] + + @staticmethod + def check_inputs(image_latents, latents): + if image_latents.shape[0] != latents.shape[0]: + raise ValueError( + f"`image_latents` must have have same batch size as `latents`, but got {image_latents.shape[0]} and {latents.shape[0]}" + ) + + if image_latents.ndim != 3: + raise ValueError(f"`image_latents` must have 3 dimensions (patchified), but got {image_latents.ndim}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + self.check_inputs( + image_latents=block_state.image_latents, + latents=block_state.latents, + ) + + # prepare latent timestep + latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0]) + + # make copy of initial_noise + block_state.initial_noise = block_state.latents + + # scale noise + block_state.latents = components.scheduler.scale_noise( + block_state.image_latents, latent_timestep, block_state.latents + ) + + self.set_block_state(state, block_state) + + return components, state + + +# auto_docstring +class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks): + """ + Step that creates mask latents from preprocessed mask_image by interpolating to latent space. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + processed_mask_image (`Tensor`): + The processed mask to use for the inpainting process. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. + + Outputs: + mask (`Tensor`): + The mask to use for the inpainting process. + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Step that creates mask latents from preprocessed mask_image by interpolating to latent space." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="processed_mask_image", + required=True, + type_hint=torch.Tensor, + description="The processed mask to use for the inpainting process.", + ), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("dtype"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process." + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + + height_latents = 2 * (int(block_state.height) // (components.vae_scale_factor * 2)) + width_latents = 2 * (int(block_state.width) // (components.vae_scale_factor * 2)) + + block_state.mask = torch.nn.functional.interpolate( + block_state.processed_mask_image, + size=(height_latents, width_latents), + ) + + block_state.mask = block_state.mask.unsqueeze(2) + block_state.mask = block_state.mask.repeat(1, components.num_channels_latents, 1, 1, 1) + block_state.mask = block_state.mask.to(device=device, dtype=block_state.dtype) + + block_state.mask = components.pachifier.pack_latents(block_state.mask) + + self.set_block_state(state, block_state) + + return components, state + + +# ==================== +# 2. SET TIMESTEPS +# ==================== + + +# auto_docstring +class QwenImageSetTimestepsStep(ModularPipelineBlocks): + """ + Step that sets the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + latents (`Tensor`): + The initial random noised latents for the denoising process. Can be generated in prepare latents step. + + Outputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("num_inference_steps"), + InputParam.template("sigmas"), + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The initial random noised latents for the denoising process. Can be generated in prepare latents step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process" + ), + ] + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + sigmas = ( + np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps) + if block_state.sigmas is None + else block_state.sigmas + ) + + mu = calculate_shift( + image_seq_len=block_state.latents.shape[1], + base_seq_len=components.scheduler.config.get("base_image_seq_len", 256), + max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096), + base_shift=components.scheduler.config.get("base_shift", 0.5), + max_shift=components.scheduler.config.get("max_shift", 1.15), + ) + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + scheduler=components.scheduler, + num_inference_steps=block_state.num_inference_steps, + device=device, + sigmas=sigmas, + mu=mu, + ) + + components.scheduler.set_begin_index(0) + + self.set_block_state(state, block_state) + + return components, state + + +# auto_docstring +class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks): + """ + Set timesteps step for QwenImage Layered with custom mu calculation based on image_latents. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. + """ + + model_name = "qwenimage-layered" + + @property + def description(self) -> str: + return "Set timesteps step for QwenImage Layered with custom mu calculation based on image_latents." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("num_inference_steps"), + InputParam.template("sigmas"), + InputParam.template("image_latents"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process." + ), + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + + # Layered-specific mu calculation + base_seqlen = 256 * 256 / 16 / 16 # = 256 + mu = (block_state.image_latents.shape[1] / base_seqlen) ** 0.5 + + # Default sigmas if not provided + sigmas = ( + np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps) + if block_state.sigmas is None + else block_state.sigmas + ) + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + block_state.num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + + components.scheduler.set_begin_index(0) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks): + """ + Step that sets the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare + latents step. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + latents (`Tensor`): + The latents to use for the denoising process. Can be generated in prepare latents step. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + + Outputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. + num_inference_steps (`int`): + The number of denoising steps to perform at inference time. Updated based on strength. + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare latents step." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("num_inference_steps"), + InputParam.template("sigmas"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to use for the denoising process. Can be generated in prepare latents step.", + ), + InputParam.template("strength", default=0.9), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="timesteps", + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process.", + ), + OutputParam( + name="num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time. Updated based on strength.", + ), + ] + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + sigmas = ( + np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps) + if block_state.sigmas is None + else block_state.sigmas + ) + + mu = calculate_shift( + image_seq_len=block_state.latents.shape[1], + base_seq_len=components.scheduler.config.get("base_image_seq_len", 256), + max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096), + base_shift=components.scheduler.config.get("base_shift", 0.5), + max_shift=components.scheduler.config.get("max_shift", 1.15), + ) + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + scheduler=components.scheduler, + num_inference_steps=block_state.num_inference_steps, + device=device, + sigmas=sigmas, + mu=mu, + ) + + block_state.timesteps, block_state.num_inference_steps = get_timesteps( + scheduler=components.scheduler, + num_inference_steps=block_state.num_inference_steps, + strength=block_state.strength, + ) + + self.set_block_state(state, block_state) + + return components, state + + +# ==================== +# 3. OTHER INPUTS FOR DENOISER +# ==================== + +## RoPE inputs for denoiser + + +# auto_docstring +class QwenImageRoPEInputsStep(ModularPipelineBlocks): + """ + Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step + + Inputs: + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + + Outputs: + img_shapes (`list`): + The shapes of the images latents, used for RoPE calculation + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return ( + "Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("batch_size"), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds_mask"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="img_shapes", + kwargs_type="denoiser_input_fields", + type_hint=list[list[tuple[int, int, int]]], + description="The shapes of the images latents, used for RoPE calculation", + ), + ] + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.img_shapes = [ + [ + ( + 1, + block_state.height // components.vae_scale_factor // 2, + block_state.width // components.vae_scale_factor // 2, + ) + ] + ] * block_state.batch_size + + self.set_block_state(state, block_state) + + return components, state + + +# auto_docstring +class QwenImageEditRoPEInputsStep(ModularPipelineBlocks): + """ + Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be placed after + prepare_latents step + + Inputs: + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + image_height (`int`): + The height of the reference image. Can be generated in input step. + image_width (`int`): + The width of the reference image. Can be generated in input step. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + + Outputs: + img_shapes (`list`): + The shapes of the images latents, used for RoPE calculation + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be placed after prepare_latents step" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("batch_size"), + InputParam( + name="image_height", + required=True, + type_hint=int, + description="The height of the reference image. Can be generated in input step.", + ), + InputParam( + name="image_width", + required=True, + type_hint=int, + description="The width of the reference image. Can be generated in input step.", + ), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds_mask"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="img_shapes", + kwargs_type="denoiser_input_fields", + type_hint=list[list[tuple[int, int, int]]], + description="The shapes of the images latents, used for RoPE calculation", + ), + ] + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # for edit, image size can be different from the target size (height/width) + block_state.img_shapes = [ + [ + ( + 1, + block_state.height // components.vae_scale_factor // 2, + block_state.width // components.vae_scale_factor // 2, + ), + ( + 1, + block_state.image_height // components.vae_scale_factor // 2, + block_state.image_width // components.vae_scale_factor // 2, + ), + ] + ] * block_state.batch_size + + self.set_block_state(state, block_state) + + return components, state + + +# auto_docstring +class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks): + """ + Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit Plus. + Unlike Edit, Edit Plus handles lists of image_height/image_width for multiple reference images. Should be placed + after prepare_latents step. + + Inputs: + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + image_height (`list`): + The heights of the reference images. Can be generated in input step. + image_width (`list`): + The widths of the reference images. Can be generated in input step. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + + Outputs: + img_shapes (`list`): + The shapes of the image latents, used for RoPE calculation + txt_seq_lens (`list`): + The sequence lengths of the prompt embeds, used for RoPE calculation + negative_txt_seq_lens (`list`): + The sequence lengths of the negative prompt embeds, used for RoPE calculation + """ + + model_name = "qwenimage-edit-plus" + + @property + def description(self) -> str: + return ( + "Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit Plus.\n" + "Unlike Edit, Edit Plus handles lists of image_height/image_width for multiple reference images.\n" + "Should be placed after prepare_latents step." + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("batch_size"), + InputParam( + name="image_height", + required=True, + type_hint=list[int], + description="The heights of the reference images. Can be generated in input step.", + ), + InputParam( + name="image_width", + required=True, + type_hint=list[int], + description="The widths of the reference images. Can be generated in input step.", + ), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds_mask"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="img_shapes", + kwargs_type="denoiser_input_fields", + type_hint=list[list[tuple[int, int, int]]], + description="The shapes of the image latents, used for RoPE calculation", + ), + OutputParam( + name="txt_seq_lens", + kwargs_type="denoiser_input_fields", + type_hint=list[int], + description="The sequence lengths of the prompt embeds, used for RoPE calculation", + ), + OutputParam( + name="negative_txt_seq_lens", + kwargs_type="denoiser_input_fields", + type_hint=list[int], + description="The sequence lengths of the negative prompt embeds, used for RoPE calculation", + ), + ] + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + vae_scale_factor = components.vae_scale_factor + + # Edit Plus: image_height and image_width are lists + block_state.img_shapes = [ + [ + (1, block_state.height // vae_scale_factor // 2, block_state.width // vae_scale_factor // 2), + *[ + (1, img_height // vae_scale_factor // 2, img_width // vae_scale_factor // 2) + for img_height, img_width in zip(block_state.image_height, block_state.image_width) + ], + ] + ] * block_state.batch_size + + block_state.txt_seq_lens = ( + block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None + ) + block_state.negative_txt_seq_lens = ( + block_state.negative_prompt_embeds_mask.sum(dim=1).tolist() + if block_state.negative_prompt_embeds_mask is not None + else None + ) + + self.set_block_state(state, block_state) + + return components, state + + +# auto_docstring +class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks): + """ + Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step + + Inputs: + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + + Outputs: + img_shapes (`list`): + The shapes of the image latents, used for RoPE calculation + txt_seq_lens (`list`): + The sequence lengths of the prompt embeds, used for RoPE calculation + negative_txt_seq_lens (`list`): + The sequence lengths of the negative prompt embeds, used for RoPE calculation + additional_t_cond (`Tensor`): + The additional t cond, used for RoPE calculation + """ + + model_name = "qwenimage-layered" + + @property + def description(self) -> str: + return ( + "Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("batch_size"), + InputParam.template("layers"), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds_mask"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="img_shapes", + type_hint=list[list[tuple[int, int, int]]], + kwargs_type="denoiser_input_fields", + description="The shapes of the image latents, used for RoPE calculation", + ), + OutputParam( + name="txt_seq_lens", + type_hint=list[int], + kwargs_type="denoiser_input_fields", + description="The sequence lengths of the prompt embeds, used for RoPE calculation", + ), + OutputParam( + name="negative_txt_seq_lens", + type_hint=list[int], + kwargs_type="denoiser_input_fields", + description="The sequence lengths of the negative prompt embeds, used for RoPE calculation", + ), + OutputParam( + name="additional_t_cond", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="The additional t cond, used for RoPE calculation", + ), + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + + # All shapes are the same for Layered + shape = ( + 1, + block_state.height // components.vae_scale_factor // 2, + block_state.width // components.vae_scale_factor // 2, + ) + + # layers+1 output shapes + 1 condition shape (all same) + block_state.img_shapes = [[shape] * (block_state.layers + 2)] * block_state.batch_size + + # txt_seq_lens + block_state.txt_seq_lens = ( + block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None + ) + block_state.negative_txt_seq_lens = ( + block_state.negative_prompt_embeds_mask.sum(dim=1).tolist() + if block_state.negative_prompt_embeds_mask is not None + else None + ) + + block_state.additional_t_cond = torch.tensor([0] * block_state.batch_size).to(device=device, dtype=torch.long) + + self.set_block_state(state, block_state) + return components, state + + +## ControlNet inputs for denoiser + + +# auto_docstring +class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks): + """ + step that prepare inputs for controlnet. Insert before the Denoise Step, after set_timesteps step. + + Components: + controlnet (`QwenImageControlNetModel`) + + Inputs: + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + + Outputs: + controlnet_keep (`list`): + The controlnet keep values + """ + + model_name = "qwenimage" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("controlnet", QwenImageControlNetModel), + ] + + @property + def description(self) -> str: + return "step that prepare inputs for controlnet. Insert before the Denoise Step, after set_timesteps step." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("control_guidance_start"), + InputParam.template("control_guidance_end"), + InputParam.template("controlnet_conditioning_scale"), + InputParam( + name="control_image_latents", + required=True, + type_hint=torch.Tensor, + description="The control image latents to use for the denoising process. Can be generated in controlnet vae encoder step.", + ), + InputParam( + name="timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("controlnet_keep", type_hint=list[float], description="The controlnet keep values"), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + controlnet = unwrap_module(components.controlnet) + + # control_guidance_start/control_guidance_end (align format) + if not isinstance(block_state.control_guidance_start, list) and isinstance( + block_state.control_guidance_end, list + ): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [ + block_state.control_guidance_start + ] + elif not isinstance(block_state.control_guidance_end, list) and isinstance( + block_state.control_guidance_start, list + ): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [ + block_state.control_guidance_end + ] + elif not isinstance(block_state.control_guidance_start, list) and not isinstance( + block_state.control_guidance_end, list + ): + mult = ( + len(block_state.control_image_latents) if isinstance(controlnet, QwenImageMultiControlNetModel) else 1 + ) + block_state.control_guidance_start, block_state.control_guidance_end = ( + mult * [block_state.control_guidance_start], + mult * [block_state.control_guidance_end], + ) + + # controlnet_conditioning_scale (align format) + if isinstance(controlnet, QwenImageMultiControlNetModel) and isinstance( + block_state.controlnet_conditioning_scale, float + ): + block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * mult + + # controlnet_keep + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): + keeps = [ + 1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e) + for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end) + ] + block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, QwenImageControlNetModel) else keeps) + + self.set_block_state(state, block_state) + + return components, state diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/decoders.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..e4ccb6b8e0470dbd9140e4f4927050801c5d629d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/decoders.py @@ -0,0 +1,511 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any + +import torch + +from ...configuration_utils import FrozenDict +from ...image_processor import InpaintProcessor, VaeImageProcessor +from ...models import AutoencoderKLQwenImage +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier + + +logger = logging.get_logger(__name__) + + +# after denoising loop (unpack latents) + + +# auto_docstring +class QwenImageAfterDenoiseStep(ModularPipelineBlocks): + """ + Step that unpack the latents from 3D tensor (batch_size, sequence_length, channels) into 5D tensor (batch_size, + channels, 1, height, width) + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + latents (`Tensor`): + The latents to decode, can be generated in the denoise step. + + Outputs: + latents (`Tensor`): + The denoisedlatents unpacked to B, C, 1, H, W + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Step that unpack the latents from 3D tensor (batch_size, sequence_length, channels) into 5D tensor (batch_size, channels, 1, height, width)" + + @property + def expected_components(self) -> list[ComponentSpec]: + components = [ + ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"), + ] + + return components + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The latents to decode, can be generated in the denoise step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="latents", type_hint=torch.Tensor, description="The denoisedlatents unpacked to B, C, 1, H, W" + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + vae_scale_factor = components.vae_scale_factor + block_state.latents = components.pachifier.unpack_latents( + block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor + ) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks): + """ + Unpack latents from (B, seq, C*4) to (B, C, layers+1, H, W) after denoising. + + Components: + pachifier (`QwenImageLayeredPachifier`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + + Outputs: + latents (`Tensor`): + Denoised latents. (unpacked to B, C, layers+1, H, W) + """ + + model_name = "qwenimage-layered" + + @property + def description(self) -> str: + return "Unpack latents from (B, seq, C*4) to (B, C, layers+1, H, W) after denoising." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config"), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents to decode, can be generated in the denoise step.", + ), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("layers"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam.template("latents", note="unpacked to B, C, layers+1, H, W"), + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Unpack: (B, seq, C*4) -> (B, C, layers+1, H, W) + block_state.latents = components.pachifier.unpack_latents( + block_state.latents, + block_state.height, + block_state.width, + block_state.layers, + components.vae_scale_factor, + ) + + self.set_block_state(state, block_state) + return components, state + + +# decode step + + +# auto_docstring +class QwenImageDecoderStep(ModularPipelineBlocks): + """ + Step that decodes the latents to images + + Components: + vae (`AutoencoderKLQwenImage`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + + Outputs: + images (`list`): + Generated images. (tensor output of the vae decoder.) + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Step that decodes the latents to images" + + @property + def expected_components(self) -> list[ComponentSpec]: + components = [ + ComponentSpec("vae", AutoencoderKLQwenImage), + ] + + return components + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam.template("images", note="tensor output of the vae decoder.")] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # YiYi Notes: remove support for output_type = "latents', we can just skip decode/encode step in modular + if block_state.latents.ndim == 4: + block_state.latents = block_state.latents.unsqueeze(dim=1) + elif block_state.latents.ndim != 5: + raise ValueError( + f"expect latents to be a 4D or 5D tensor but got: {block_state.latents.shape}. Please make sure the latents are unpacked before decode step." + ) + block_state.latents = block_state.latents.to(components.vae.dtype) + + latents_mean = ( + torch.tensor(components.vae.config.latents_mean) + .view(1, components.vae.config.z_dim, 1, 1, 1) + .to(block_state.latents.device, block_state.latents.dtype) + ) + latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view( + 1, components.vae.config.z_dim, 1, 1, 1 + ).to(block_state.latents.device, block_state.latents.dtype) + block_state.latents = block_state.latents / latents_std + latents_mean + block_state.images = components.vae.decode(block_state.latents, return_dict=False)[0][:, :, 0] + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageLayeredDecoderStep(ModularPipelineBlocks): + """ + Decode unpacked latents (B, C, layers+1, H, W) into layer images. + + Components: + vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "qwenimage-layered" + + @property + def description(self) -> str: + return "Decode unpacked latents (B, C, layers+1, H, W) into layer images." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLQwenImage), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise step.", + ), + InputParam.template("output_type"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam.template("images")] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + latents = block_state.latents + + # 1. VAE normalization + latents = latents.to(components.vae.dtype) + latents_mean = ( + torch.tensor(components.vae.config.latents_mean) + .view(1, components.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view( + 1, components.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + + # 2. Reshape for batch decoding: (B, C, layers+1, H, W) -> (B*layers, C, 1, H, W) + b, c, f, h, w = latents.shape + # 3. Remove first frame (composite), keep layers frames + latents = latents[:, :, 1:] + latents = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w) + + # 4. Decode: (B*layers, C, 1, H, W) -> (B*layers, C, H, W) + image = components.vae.decode(latents, return_dict=False)[0] + image = image.squeeze(2) + + # 5. Postprocess - returns flat list of B*layers images + image = components.image_processor.postprocess(image, output_type=block_state.output_type) + + # 6. Chunk into list per batch item + images = [] + for bidx in range(b): + images.append(image[bidx * f : (bidx + 1) * f]) + + block_state.images = images + + self.set_block_state(state, block_state) + return components, state + + +# postprocess the decoded images + + +# auto_docstring +class QwenImageProcessImagesOutputStep(ModularPipelineBlocks): + """ + postprocess the generated image + + Components: + image_processor (`VaeImageProcessor`) + + Inputs: + images (`Tensor`): + the generated image tensor from decoders step + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "postprocess the generated image" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="images", + required=True, + type_hint=torch.Tensor, + description="the generated image tensor from decoders step", + ), + InputParam.template("output_type"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam.template("images")] + + @staticmethod + def check_inputs(output_type): + if output_type not in ["pil", "np", "pt"]: + raise ValueError(f"Invalid output_type: {output_type}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + self.check_inputs(block_state.output_type) + + block_state.images = components.image_processor.postprocess( + image=block_state.images, + output_type=block_state.output_type, + ) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks): + """ + postprocess the generated image, optional apply the mask overally to the original image.. + + Components: + image_mask_processor (`InpaintProcessor`) + + Inputs: + images (`Tensor`): + the generated image tensor from decoders step + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + mask_overlay_kwargs (`dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. generated in + InpaintProcessImagesInputStep. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "postprocess the generated image, optional apply the mask overally to the original image.." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_mask_processor", + InpaintProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="images", + required=True, + type_hint=torch.Tensor, + description="the generated image tensor from decoders step", + ), + InputParam.template("output_type"), + InputParam( + name="mask_overlay_kwargs", + type_hint=dict[str, Any], + description="The kwargs for the postprocess step to apply the mask overlay. generated in InpaintProcessImagesInputStep.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam.template("images")] + + @staticmethod + def check_inputs(output_type, mask_overlay_kwargs): + if output_type not in ["pil", "np", "pt"]: + raise ValueError(f"Invalid output_type: {output_type}") + + if mask_overlay_kwargs and output_type != "pil": + raise ValueError("only support output_type 'pil' for mask overlay") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + self.check_inputs(block_state.output_type, block_state.mask_overlay_kwargs) + + if block_state.mask_overlay_kwargs is None: + mask_overlay_kwargs = {} + else: + mask_overlay_kwargs = block_state.mask_overlay_kwargs + + block_state.images = components.image_mask_processor.postprocess( + image=block_state.images, + **mask_overlay_kwargs, + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/denoise.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..de8ea05c5047cbe346e0ce61e59e2f6c5827d3f2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/denoise.py @@ -0,0 +1,943 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import QwenImageControlNetModel, QwenImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ..modular_pipeline import BlockState, LoopSequentialPipelineBlocks, ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import QwenImageModularPipeline + + +logger = logging.get_logger(__name__) + +# ==================== +# 1. LOOP STEPS (run at each denoising step) +# ==================== + + +# loop step:before denoiser +class QwenImageLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepares the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `QwenImageDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + # one timestep + block_state.timestep = t.expand(block_state.latents.shape[0]).to(block_state.latents.dtype) + block_state.latent_model_input = block_state.latents + return components, block_state + + +class QwenImageEditLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "qwenimage-edit" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepares the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `QwenImageDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam.template("image_latents"), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + # one timestep + + block_state.latent_model_input = torch.cat([block_state.latents, block_state.image_latents], dim=1) + block_state.timestep = t.expand(block_state.latents.shape[0]).to(block_state.latents.dtype) + return components, block_state + + +class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ComponentSpec("controlnet", QwenImageControlNetModel), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that runs the controlnet before the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `QwenImageDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "control_image_latents", + required=True, + type_hint=torch.Tensor, + description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", + ), + InputParam.template("controlnet_conditioning_scale", note="updated in prepare_controlnet_inputs step."), + InputParam( + name="controlnet_keep", + required=True, + type_hint=list[float], + description="The controlnet keep values. Can be generated in prepare_controlnet_inputs step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: int): + # cond_scale for the timestep (controlnet input) + if isinstance(block_state.controlnet_keep[i], list): + block_state.cond_scale = [ + c * s for c, s in zip(block_state.controlnet_conditioning_scale, block_state.controlnet_keep[i]) + ] + else: + controlnet_cond_scale = block_state.controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i] + + # run controlnet for the guidance batch + controlnet_block_samples = components.controlnet( + hidden_states=block_state.latent_model_input, + controlnet_cond=block_state.control_image_latents, + conditioning_scale=block_state.cond_scale, + timestep=block_state.timestep / 1000, + img_shapes=block_state.img_shapes, + encoder_hidden_states=block_state.prompt_embeds, + encoder_hidden_states_mask=block_state.prompt_embeds_mask, + return_dict=False, + ) + + block_state.additional_cond_kwargs["controlnet_block_samples"] = controlnet_block_samples + + return components, block_state + + +# loop step:denoiser +class QwenImageLoopDenoiser(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that denoise the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `QwenImageDenoiseLoopWrapper`)" + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", QwenImageTransformer2DModel), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("attention_kwargs"), + InputParam.template("denoiser_input_fields"), + InputParam( + "img_shapes", + required=True, + type_hint=list[tuple[int, int]], + description="The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + guider_inputs = { + "encoder_hidden_states": ( + getattr(block_state, "prompt_embeds", None), + getattr(block_state, "negative_prompt_embeds", None), + ), + "encoder_hidden_states_mask": ( + getattr(block_state, "prompt_embeds_mask", None), + getattr(block_state, "negative_prompt_embeds_mask", None), + ), + } + + transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys()) + additional_cond_kwargs = {} + for field_name, field_value in block_state.denoiser_input_fields.items(): + if field_name in transformer_args and field_name not in guider_inputs: + additional_cond_kwargs[field_name] = field_value + block_state.additional_cond_kwargs.update(additional_cond_kwargs) + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs(guider_inputs) + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()} + + # YiYi TODO: add cache context + guider_state_batch.noise_pred = components.transformer( + hidden_states=block_state.latent_model_input, + timestep=block_state.timestep / 1000, + attention_kwargs=block_state.attention_kwargs, + return_dict=False, + **cond_kwargs, + **block_state.additional_cond_kwargs, + )[0] + + components.guider.cleanup_models(components.transformer) + + guider_output = components.guider(guider_state) + + # apply guidance rescale + pred_cond_norm = torch.norm(guider_output.pred_cond, dim=-1, keepdim=True) + pred_norm = torch.norm(guider_output.pred, dim=-1, keepdim=True) + block_state.noise_pred = guider_output.pred * (pred_cond_norm / pred_norm) + + return components, block_state + + +class QwenImageEditLoopDenoiser(ModularPipelineBlocks): + model_name = "qwenimage-edit" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that denoise the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `QwenImageDenoiseLoopWrapper`)" + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", QwenImageTransformer2DModel), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("attention_kwargs"), + InputParam.template("denoiser_input_fields"), + InputParam( + "img_shapes", + required=True, + type_hint=list[tuple[int, int]], + description="The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + guider_inputs = { + "encoder_hidden_states": ( + getattr(block_state, "prompt_embeds", None), + getattr(block_state, "negative_prompt_embeds", None), + ), + "encoder_hidden_states_mask": ( + getattr(block_state, "prompt_embeds_mask", None), + getattr(block_state, "negative_prompt_embeds_mask", None), + ), + } + + transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys()) + additional_cond_kwargs = {} + for field_name, field_value in block_state.denoiser_input_fields.items(): + if field_name in transformer_args and field_name not in guider_inputs: + additional_cond_kwargs[field_name] = field_value + block_state.additional_cond_kwargs.update(additional_cond_kwargs) + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs(guider_inputs) + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()} + + # YiYi TODO: add cache context + guider_state_batch.noise_pred = components.transformer( + hidden_states=block_state.latent_model_input, + timestep=block_state.timestep / 1000, + attention_kwargs=block_state.attention_kwargs, + return_dict=False, + **cond_kwargs, + **block_state.additional_cond_kwargs, + )[0] + + components.guider.cleanup_models(components.transformer) + + guider_output = components.guider(guider_state) + + pred = guider_output.pred[:, : block_state.latents.size(1)] + pred_cond = guider_output.pred_cond[:, : block_state.latents.size(1)] + + # apply guidance rescale + pred_cond_norm = torch.norm(pred_cond, dim=-1, keepdim=True) + pred_norm = torch.norm(pred, dim=-1, keepdim=True) + block_state.noise_pred = pred * (pred_cond_norm / pred_norm) + + return components, block_state + + +# loop step:after denoiser +class QwenImageLoopAfterDenoiser(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that updates the latents. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `QwenImageDenoiseLoopWrapper`)" + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam.template("latents"), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred, + t, + block_state.latents, + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class QwenImageLoopAfterDenoiserInpaint(ModularPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that updates the latents using mask and image_latents for inpainting. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `QwenImageDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "mask", + required=True, + type_hint=torch.Tensor, + description="The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.", + ), + InputParam.template("image_latents"), + InputParam( + "initial_noise", + required=True, + type_hint=torch.Tensor, + description="The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam.template("latents"), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + block_state.init_latents_proper = block_state.image_latents + if i < len(block_state.timesteps) - 1: + block_state.noise_timestep = block_state.timesteps[i + 1] + block_state.init_latents_proper = components.scheduler.scale_noise( + block_state.init_latents_proper, torch.tensor([block_state.noise_timestep]), block_state.initial_noise + ) + + block_state.latents = ( + 1 - block_state.mask + ) * block_state.init_latents_proper + block_state.mask * block_state.latents + + return components, block_state + + +# ==================== +# 2. DENOISE LOOP WRAPPER: define the denoising loop logic +# ==================== +class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "qwenimage" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoise the latents over `timesteps`. " + "The specific steps with each iteration can be customized with `sub_blocks` attributes" + ) + + @property + def loop_expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def loop_inputs(self) -> list[InputParam]: + return [ + InputParam( + name="timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam.template("num_inference_steps", required=True), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) + + block_state.additional_cond_kwargs = {} + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + self.set_block_state(state, block_state) + + return components, state + + +# ==================== +# 3. DENOISE STEPS: compose the denoising loop with loop wrapper + loop steps +# ==================== + + +# Qwen Image (text2image, image2image) + + +# auto_docstring +class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageLoopBeforeDenoiser` + - `QwenImageLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + This block supports text2image and image2image tasks for QwenImage. + + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`list`): + The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + + block_classes = [ + QwenImageLoopBeforeDenoiser, + QwenImageLoopDenoiser, + QwenImageLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents.\n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method\n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `QwenImageLoopBeforeDenoiser`\n" + " - `QwenImageLoopDenoiser`\n" + " - `QwenImageLoopAfterDenoiser`\n" + "This block supports text2image and image2image tasks for QwenImage." + ) + + +# Qwen Image (inpainting) +# auto_docstring +class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageLoopBeforeDenoiser` + - `QwenImageLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + - `QwenImageLoopAfterDenoiserInpaint` + This block supports inpainting tasks for QwenImage. + + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`list`): + The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step. + mask (`Tensor`): + The mask to use for the inpainting process. Can be generated in inpaint prepare latents step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + initial_noise (`Tensor`): + The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageLoopBeforeDenoiser, + QwenImageLoopDenoiser, + QwenImageLoopAfterDenoiser, + QwenImageLoopAfterDenoiserInpaint, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser", "after_denoiser_inpaint"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `QwenImageLoopBeforeDenoiser`\n" + " - `QwenImageLoopDenoiser`\n" + " - `QwenImageLoopAfterDenoiser`\n" + " - `QwenImageLoopAfterDenoiserInpaint`\n" + "This block supports inpainting tasks for QwenImage." + ) + + +# Qwen Image (text2image, image2image) with controlnet +# auto_docstring +class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageLoopBeforeDenoiser` + - `QwenImageLoopBeforeDenoiserControlNet` + - `QwenImageLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + This block supports text2img/img2img tasks with controlnet for QwenImage. + + Components: + guider (`ClassifierFreeGuidance`) controlnet (`QwenImageControlNetModel`) transformer + (`QwenImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + control_image_latents (`Tensor`): + The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. (updated in prepare_controlnet_inputs step.) + controlnet_keep (`list`): + The controlnet keep values. Can be generated in prepare_controlnet_inputs step. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`list`): + The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageLoopBeforeDenoiser, + QwenImageLoopBeforeDenoiserControlNet, + QwenImageLoopDenoiser, + QwenImageLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "before_denoiser_controlnet", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `QwenImageLoopBeforeDenoiser`\n" + " - `QwenImageLoopBeforeDenoiserControlNet`\n" + " - `QwenImageLoopDenoiser`\n" + " - `QwenImageLoopAfterDenoiser`\n" + "This block supports text2img/img2img tasks with controlnet for QwenImage." + ) + + +# Qwen Image (inpainting) with controlnet +# auto_docstring +class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageLoopBeforeDenoiser` + - `QwenImageLoopBeforeDenoiserControlNet` + - `QwenImageLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + - `QwenImageLoopAfterDenoiserInpaint` + This block supports inpainting tasks with controlnet for QwenImage. + + Components: + guider (`ClassifierFreeGuidance`) controlnet (`QwenImageControlNetModel`) transformer + (`QwenImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + control_image_latents (`Tensor`): + The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. (updated in prepare_controlnet_inputs step.) + controlnet_keep (`list`): + The controlnet keep values. Can be generated in prepare_controlnet_inputs step. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`list`): + The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step. + mask (`Tensor`): + The mask to use for the inpainting process. Can be generated in inpaint prepare latents step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + initial_noise (`Tensor`): + The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageLoopBeforeDenoiser, + QwenImageLoopBeforeDenoiserControlNet, + QwenImageLoopDenoiser, + QwenImageLoopAfterDenoiser, + QwenImageLoopAfterDenoiserInpaint, + ] + block_names = [ + "before_denoiser", + "before_denoiser_controlnet", + "denoiser", + "after_denoiser", + "after_denoiser_inpaint", + ] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `QwenImageLoopBeforeDenoiser`\n" + " - `QwenImageLoopBeforeDenoiserControlNet`\n" + " - `QwenImageLoopDenoiser`\n" + " - `QwenImageLoopAfterDenoiser`\n" + " - `QwenImageLoopAfterDenoiserInpaint`\n" + "This block supports inpainting tasks with controlnet for QwenImage." + ) + + +# Qwen Image Edit (image2image) +# auto_docstring +class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageEditLoopBeforeDenoiser` + - `QwenImageEditLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + This block supports QwenImage Edit. + + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`list`): + The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageEditLoopBeforeDenoiser, + QwenImageEditLoopDenoiser, + QwenImageLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `QwenImageEditLoopBeforeDenoiser`\n" + " - `QwenImageEditLoopDenoiser`\n" + " - `QwenImageLoopAfterDenoiser`\n" + "This block supports QwenImage Edit." + ) + + +# Qwen Image Edit (inpainting) +# auto_docstring +class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageEditLoopBeforeDenoiser` + - `QwenImageEditLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + - `QwenImageLoopAfterDenoiserInpaint` + This block supports inpainting tasks for QwenImage Edit. + + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`list`): + The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step. + mask (`Tensor`): + The mask to use for the inpainting process. Can be generated in inpaint prepare latents step. + initial_noise (`Tensor`): + The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageEditLoopBeforeDenoiser, + QwenImageEditLoopDenoiser, + QwenImageLoopAfterDenoiser, + QwenImageLoopAfterDenoiserInpaint, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser", "after_denoiser_inpaint"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `QwenImageEditLoopBeforeDenoiser`\n" + " - `QwenImageEditLoopDenoiser`\n" + " - `QwenImageLoopAfterDenoiser`\n" + " - `QwenImageLoopAfterDenoiserInpaint`\n" + "This block supports inpainting tasks for QwenImage Edit." + ) + + +# Qwen Image Layered (image2image) +# auto_docstring +class QwenImageLayeredDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageEditLoopBeforeDenoiser` + - `QwenImageEditLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + This block supports QwenImage Layered. + + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`list`): + The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage-layered" + block_classes = [ + QwenImageEditLoopBeforeDenoiser, + QwenImageEditLoopDenoiser, + QwenImageLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `QwenImageEditLoopBeforeDenoiser`\n" + " - `QwenImageEditLoopDenoiser`\n" + " - `QwenImageLoopAfterDenoiser`\n" + "This block supports QwenImage Layered." + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/encoders.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..527267dc0d6eb10348d9d9a0a87ea2154a0c6261 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/encoders.py @@ -0,0 +1,1780 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Text and VAE encoder blocks for QwenImage pipelines. +""" + +import PIL +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...image_processor import InpaintProcessor, VaeImageProcessor, is_valid_image, is_valid_image_imagelist +from ...models import AutoencoderKLQwenImage, QwenImageControlNetModel, QwenImageMultiControlNetModel +from ...pipelines.qwenimage.pipeline_qwenimage_edit import calculate_dimensions +from ...utils import logging +from ...utils.torch_utils import unwrap_module +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import QwenImageModularPipeline +from .prompt_templates import ( + QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE, + QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE, + QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX, + QWENIMAGE_EDIT_PROMPT_TEMPLATE, + QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX, + QWENIMAGE_LAYERED_CAPTION_PROMPT_CN, + QWENIMAGE_LAYERED_CAPTION_PROMPT_EN, + QWENIMAGE_PROMPT_TEMPLATE, + QWENIMAGE_PROMPT_TEMPLATE_START_IDX, +) + + +logger = logging.get_logger(__name__) + + +def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + return split_result + + +def get_qwen_prompt_embeds( + text_encoder, + tokenizer, + prompt: str | list[str] = None, + prompt_template_encode: str = QWENIMAGE_PROMPT_TEMPLATE, + prompt_template_encode_start_idx: int = QWENIMAGE_PROMPT_TEMPLATE_START_IDX, + tokenizer_max_length: int = 1024, + device: torch.device | None = None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = prompt_template_encode + drop_idx = prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + txt_tokens = tokenizer( + txt, max_length=tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + ).to(device) + encoder_hidden_states = text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + + split_hidden_states = _extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(device=device) + + return prompt_embeds, encoder_attention_mask + + +def get_qwen_prompt_embeds_edit( + text_encoder, + processor, + prompt: str | list[str] = None, + image: torch.Tensor | None = None, + prompt_template_encode: str = QWENIMAGE_EDIT_PROMPT_TEMPLATE, + prompt_template_encode_start_idx: int = QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX, + device: torch.device | None = None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = prompt_template_encode + drop_idx = prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + + model_inputs = processor( + text=txt, + images=image, + padding=True, + return_tensors="pt", + ).to(device) + + outputs = text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + + hidden_states = outputs.hidden_states[-1] + split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(device=device) + + return prompt_embeds, encoder_attention_mask + + +def get_qwen_prompt_embeds_edit_plus( + text_encoder, + processor, + prompt: str | list[str] = None, + image: torch.Tensor | list[PIL.Image.Image, PIL.Image.Image] | None = None, + prompt_template_encode: str = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE, + img_template_encode: str = QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE, + prompt_template_encode_start_idx: int = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX, + device: torch.device | None = None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + if isinstance(image, list): + base_img_prompt = "" + for i, img in enumerate(image): + base_img_prompt += img_template_encode.format(i + 1) + elif image is not None: + base_img_prompt = img_template_encode.format(1) + else: + base_img_prompt = "" + + template = prompt_template_encode + + drop_idx = prompt_template_encode_start_idx + txt = [template.format(base_img_prompt + e) for e in prompt] + + model_inputs = processor( + text=txt, + images=image, + padding=True, + return_tensors="pt", + ).to(device) + outputs = text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + + hidden_states = outputs.hidden_states[-1] + split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(device=device) + return prompt_embeds, encoder_attention_mask + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Modified from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._encode_vae_image +def encode_vae_image( + image: torch.Tensor, + vae: AutoencoderKLQwenImage, + generator: torch.Generator, + device: torch.device, + dtype: torch.dtype, + latent_channels: int = 16, + sample_mode: str = "argmax", +): + if not isinstance(image, torch.Tensor): + raise ValueError(f"Expected image to be a tensor, got {type(image)}.") + + # preprocessed image should be a 4D tensor: batch_size, num_channels, height, width + if image.dim() == 4: + image = image.unsqueeze(2) + elif image.dim() != 5: + raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.") + + image = image.to(device=device, dtype=dtype) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode) + latents_mean = ( + torch.tensor(vae.config.latents_mean) + .view(1, latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = ( + torch.tensor(vae.config.latents_std) + .view(1, latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + image_latents = (image_latents - latents_mean) / latents_std + + return image_latents + + +# ==================== +# 1. RESIZE +# ==================== +# In QwenImage pipelines, resize is a separate step because the resized image is used in VL encoding and vae encoder blocks: +# +# image (PIL.Image.Image) +# │ +# ▼ +# resized_image ([PIL.Image.Image]) +# │ +# ├──► text_encoder ──► prompt_embeds, prompt_embeds_mask +# │ (VL encoding needs the resized image for vision-language fusion) +# │ +# └──► image_processor ──► processed_image (torch.Tensor, pixel space) +# │ +# ▼ +# vae_encoder ──► image_latents (torch.Tensor, latent space) +# +# In most of our other pipelines, resizing is done as part of the image preprocessing step. +# ==================== + + +# auto_docstring +class QwenImageEditResizeStep(ModularPipelineBlocks): + """ + Image Resize step that resize the image to target area while maintaining the aspect ratio. + + Components: + image_resize_processor (`VaeImageProcessor`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + + Outputs: + resized_image (`list`): + The resized images + """ + + model_name = "qwenimage-edit" + + @property + def description(self) -> str: + return "Image Resize step that resize the image to target area while maintaining the aspect ratio." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_resize_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [InputParam.template("image")] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="resized_image", + type_hint=list[PIL.Image.Image], + description="The resized images", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + images = block_state.image + + if not is_valid_image_imagelist(images): + raise ValueError(f"Images must be image or list of images but are {type(images)}") + + if is_valid_image(images): + images = [images] + + image_width, image_height = images[0].size + calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_width / image_height) + + resized_images = [ + components.image_resize_processor.resize(image, height=calculated_height, width=calculated_width) + for image in images + ] + + block_state.resized_image = resized_images + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageLayeredResizeStep(ModularPipelineBlocks): + """ + Image Resize step that resize the image to a target area (defined by the resolution parameter from user) while + maintaining the aspect ratio. + + Components: + image_resize_processor (`VaeImageProcessor`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + resolution (`int`, *optional*, defaults to 640): + The target area to resize the image to, can be 1024 or 640 + + Outputs: + resized_image (`list`): + The resized images + """ + + model_name = "qwenimage-layered" + + @property + def description(self) -> str: + return "Image Resize step that resize the image to a target area (defined by the resolution parameter from user) while maintaining the aspect ratio." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_resize_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("image"), + InputParam( + name="resolution", + default=640, + type_hint=int, + description="The target area to resize the image to, can be 1024 or 640", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="resized_image", + type_hint=list[PIL.Image.Image], + description="The resized images", + ) + ] + + @staticmethod + def check_inputs(resolution: int): + if resolution not in [1024, 640]: + raise ValueError(f"Resolution must be 1024 or 640 but is {resolution}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + self.check_inputs(resolution=block_state.resolution) + + images = block_state.image + + if not is_valid_image_imagelist(images): + raise ValueError(f"Images must be image or list of images but are {type(images)}") + + if is_valid_image(images): + images = [images] + + image_width, image_height = images[0].size + target_area = block_state.resolution * block_state.resolution + calculated_width, calculated_height, _ = calculate_dimensions(target_area, image_width / image_height) + + resized_images = [ + components.image_resize_processor.resize(image, height=calculated_height, width=calculated_width) + for image in images + ] + + block_state.resized_image = resized_images + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageEditPlusResizeStep(ModularPipelineBlocks): + """ + Resize images for QwenImage Edit Plus pipeline. + Produces two outputs: resized_image (1024x1024) for VAE encoding, resized_cond_image (384x384) for VL text + encoding. Each image is resized independently based on its own aspect ratio. + + Components: + image_resize_processor (`VaeImageProcessor`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + + Outputs: + resized_image (`list`): + Images resized to 1024x1024 target area for VAE encoding + resized_cond_image (`list`): + Images resized to 384x384 target area for VL text encoding + """ + + model_name = "qwenimage-edit-plus" + + @property + def description(self) -> str: + return ( + "Resize images for QwenImage Edit Plus pipeline.\n" + "Produces two outputs: resized_image (1024x1024) for VAE encoding, " + "resized_cond_image (384x384) for VL text encoding.\n" + "Each image is resized independently based on its own aspect ratio." + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_resize_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + # image + return [InputParam.template("image")] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="resized_image", + type_hint=list[PIL.Image.Image], + description="Images resized to 1024x1024 target area for VAE encoding", + ), + OutputParam( + name="resized_cond_image", + type_hint=list[PIL.Image.Image], + description="Images resized to 384x384 target area for VL text encoding", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + images = block_state.image + + if not is_valid_image_imagelist(images): + raise ValueError(f"Images must be image or list of images but are {type(images)}") + + if is_valid_image(images): + images = [images] + + # Resize each image independently based on its own aspect ratio + resized_images = [] + resized_cond_images = [] + for image in images: + image_width, image_height = image.size + + # For VAE encoder (1024x1024 target area) + vae_width, vae_height, _ = calculate_dimensions(1024 * 1024, image_width / image_height) + resized_images.append(components.image_resize_processor.resize(image, height=vae_height, width=vae_width)) + + # For VL text encoder (384x384 target area) + vl_width, vl_height, _ = calculate_dimensions(384 * 384, image_width / image_height) + resized_cond_images.append( + components.image_resize_processor.resize(image, height=vl_height, width=vl_width) + ) + + block_state.resized_image = resized_images + block_state.resized_cond_image = resized_cond_images + self.set_block_state(state, block_state) + return components, state + + +# ==================== +# 2. GET IMAGE PROMPT +# ==================== + + +# auto_docstring +class QwenImageLayeredGetImagePromptStep(ModularPipelineBlocks): + """ + Auto-caption step that generates a text prompt from the input image if none is provided. + Uses the VL model (text_encoder) to generate a description of the image. If prompt is already provided, this step + passes through unchanged. + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor (`Qwen2VLProcessor`) + + Inputs: + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + resized_image (`Image`): + The image to generate caption from, should be resized use the resize step + use_en_prompt (`bool`, *optional*, defaults to False): + Whether to use English prompt template + + Outputs: + prompt (`str`): + The prompt or prompts to guide image generation. If not provided, updated using image caption + """ + + model_name = "qwenimage-layered" + + def __init__(self): + self.image_caption_prompt_en = QWENIMAGE_LAYERED_CAPTION_PROMPT_EN + self.image_caption_prompt_cn = QWENIMAGE_LAYERED_CAPTION_PROMPT_CN + super().__init__() + + @property + def description(self) -> str: + return ( + "Auto-caption step that generates a text prompt from the input image if none is provided.\n" + "Uses the VL model (text_encoder) to generate a description of the image.\n" + "If prompt is already provided, this step passes through unchanged." + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration), + ComponentSpec("processor", Qwen2VLProcessor), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template( + "prompt", required=False + ), # it is not required for qwenimage-layered, unlike other pipelines + InputParam( + name="resized_image", + required=True, + type_hint=PIL.Image.Image, + description="The image to generate caption from, should be resized use the resize step", + ), + InputParam( + name="use_en_prompt", + default=False, + type_hint=bool, + description="Whether to use English prompt template", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="prompt", + type_hint=str, + description="The prompt or prompts to guide image generation. If not provided, updated using image caption", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + + # If prompt is empty or None, generate caption from image + if block_state.prompt is None or block_state.prompt == "" or block_state.prompt == " ": + if block_state.use_en_prompt: + caption_prompt = self.image_caption_prompt_en + else: + caption_prompt = self.image_caption_prompt_cn + + model_inputs = components.processor( + text=caption_prompt, + images=block_state.resized_image, + padding=True, + return_tensors="pt", + ).to(device) + + generated_ids = components.text_encoder.generate(**model_inputs, max_new_tokens=512) + generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids) + ] + output_text = components.processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + )[0] + + block_state.prompt = output_text.strip() + + self.set_block_state(state, block_state) + return components, state + + +# ==================== +# 3. TEXT ENCODER +# ==================== + + +# auto_docstring +class QwenImageTextEncoderStep(ModularPipelineBlocks): + """ + Text Encoder step that generates text embeddings to guide the image generation. + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`): + The tokenizer to use guider (`ClassifierFreeGuidance`) + + Inputs: + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + + Outputs: + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + + model_name = "qwenimage" + + def __init__(self): + self.prompt_template_encode = QWENIMAGE_PROMPT_TEMPLATE + self.prompt_template_encode_start_idx = QWENIMAGE_PROMPT_TEMPLATE_START_IDX + self.tokenizer_max_length = 1024 + super().__init__() + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings to guide the image generation." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration, description="The text encoder to use"), + ComponentSpec("tokenizer", Qwen2Tokenizer, description="The tokenizer to use"), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("prompt"), + InputParam.template("negative_prompt"), + InputParam.template("max_sequence_length", default=1024), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam.template("prompt_embeds"), + OutputParam.template("prompt_embeds_mask"), + OutputParam.template("negative_prompt_embeds"), + OutputParam.template("negative_prompt_embeds_mask"), + ] + + @staticmethod + def check_inputs(prompt, negative_prompt, max_sequence_length): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if ( + negative_prompt is not None + and not isinstance(negative_prompt, str) + and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + device = components._execution_device + self.check_inputs(block_state.prompt, block_state.negative_prompt, block_state.max_sequence_length) + + block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds( + components.text_encoder, + components.tokenizer, + prompt=block_state.prompt, + prompt_template_encode=self.prompt_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, + tokenizer_max_length=self.tokenizer_max_length, + device=device, + ) + + block_state.prompt_embeds = block_state.prompt_embeds[:, : block_state.max_sequence_length] + block_state.prompt_embeds_mask = block_state.prompt_embeds_mask[:, : block_state.max_sequence_length] + + block_state.negative_prompt_embeds = None + block_state.negative_prompt_embeds_mask = None + if components.requires_unconditional_embeds: + negative_prompt = block_state.negative_prompt or "" + block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds( + components.text_encoder, + components.tokenizer, + prompt=negative_prompt, + prompt_template_encode=self.prompt_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, + tokenizer_max_length=self.tokenizer_max_length, + device=device, + ) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds[ + :, : block_state.max_sequence_length + ] + block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask[ + :, : block_state.max_sequence_length + ] + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageEditTextEncoderStep(ModularPipelineBlocks): + """ + Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image + generation. + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor (`Qwen2VLProcessor`) guider + (`ClassifierFreeGuidance`) + + Inputs: + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + resized_image (`Image`): + The image prompt to encode, should be resized using resize step + + Outputs: + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + + model_name = "qwenimage" + + def __init__(self): + self.prompt_template_encode = QWENIMAGE_EDIT_PROMPT_TEMPLATE + self.prompt_template_encode_start_idx = QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX + super().__init__() + + @property + def description(self) -> str: + return "Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image generation." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration), + ComponentSpec("processor", Qwen2VLProcessor), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("prompt"), + InputParam.template("negative_prompt"), + InputParam( + name="resized_image", + required=True, + type_hint=PIL.Image.Image, + description="The image prompt to encode, should be resized using resize step", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam.template("prompt_embeds"), + OutputParam.template("prompt_embeds_mask"), + OutputParam.template("negative_prompt_embeds"), + OutputParam.template("negative_prompt_embeds_mask"), + ] + + @staticmethod + def check_inputs(prompt, negative_prompt): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if ( + negative_prompt is not None + and not isinstance(negative_prompt, str) + and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + self.check_inputs(block_state.prompt, block_state.negative_prompt) + + device = components._execution_device + + block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds_edit( + components.text_encoder, + components.processor, + prompt=block_state.prompt, + image=block_state.resized_image, + prompt_template_encode=self.prompt_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, + device=device, + ) + + block_state.negative_prompt_embeds = None + block_state.negative_prompt_embeds_mask = None + if components.requires_unconditional_embeds: + negative_prompt = block_state.negative_prompt or " " + block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit( + components.text_encoder, + components.processor, + prompt=negative_prompt, + image=block_state.resized_image, + prompt_template_encode=self.prompt_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, + device=device, + ) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageEditPlusTextEncoderStep(ModularPipelineBlocks): + """ + Text Encoder step for QwenImage Edit Plus that processes prompt and multiple images together to generate text + embeddings for guiding image generation. + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor (`Qwen2VLProcessor`) guider + (`ClassifierFreeGuidance`) + + Inputs: + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + resized_cond_image (`Tensor`): + The image(s) to encode, can be a single image or list of images, should be resized to 384x384 using + resize step + + Outputs: + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + + model_name = "qwenimage-edit-plus" + + def __init__(self): + self.prompt_template_encode = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE + self.img_template_encode = QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE + self.prompt_template_encode_start_idx = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX + super().__init__() + + @property + def description(self) -> str: + return ( + "Text Encoder step for QwenImage Edit Plus that processes prompt and multiple images together " + "to generate text embeddings for guiding image generation." + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration), + ComponentSpec("processor", Qwen2VLProcessor), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("prompt"), + InputParam.template("negative_prompt"), + InputParam( + name="resized_cond_image", + required=True, + type_hint=torch.Tensor, + description="The image(s) to encode, can be a single image or list of images, should be resized to 384x384 using resize step", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam.template("prompt_embeds"), + OutputParam.template("prompt_embeds_mask"), + OutputParam.template("negative_prompt_embeds"), + OutputParam.template("negative_prompt_embeds_mask"), + ] + + @staticmethod + def check_inputs(prompt, negative_prompt): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if ( + negative_prompt is not None + and not isinstance(negative_prompt, str) + and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + self.check_inputs(block_state.prompt, block_state.negative_prompt) + + device = components._execution_device + + block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds_edit_plus( + components.text_encoder, + components.processor, + prompt=block_state.prompt, + image=block_state.resized_cond_image, + prompt_template_encode=self.prompt_template_encode, + img_template_encode=self.img_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, + device=device, + ) + + block_state.negative_prompt_embeds = None + block_state.negative_prompt_embeds_mask = None + if components.requires_unconditional_embeds: + negative_prompt = block_state.negative_prompt or " " + block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = ( + get_qwen_prompt_embeds_edit_plus( + components.text_encoder, + components.processor, + prompt=negative_prompt, + image=block_state.resized_cond_image, + prompt_template_encode=self.prompt_template_encode, + img_template_encode=self.img_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, + device=device, + ) + ) + + self.set_block_state(state, block_state) + return components, state + + +# ==================== +# 4. IMAGE PREPROCESS +# ==================== + + +# auto_docstring +class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images will be + resized to the given height and width. + + Components: + image_mask_processor (`InpaintProcessor`) + + Inputs: + mask_image (`Image`): + Mask image for inpainting. + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + + Outputs: + processed_image (`Tensor`): + The processed image + processed_mask_image (`Tensor`): + The processed mask image + mask_overlay_kwargs (`dict`): + The kwargs for the postprocess step to apply the mask overlay + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images will be resized to the given height and width." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_mask_processor", + InpaintProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("mask_image"), + InputParam.template("image"), + InputParam.template("height"), + InputParam.template("width"), + InputParam.template("padding_mask_crop"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="processed_image", + type_hint=torch.Tensor, + description="The processed image", + ), + OutputParam( + name="processed_mask_image", + type_hint=torch.Tensor, + description="The processed mask image", + ), + OutputParam( + name="mask_overlay_kwargs", + type_hint=dict, + description="The kwargs for the postprocess step to apply the mask overlay", + ), + ] + + @staticmethod + def check_inputs(height, width, vae_scale_factor): + if height is not None and height % (vae_scale_factor * 2) != 0: + raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}") + + if width is not None and width % (vae_scale_factor * 2) != 0: + raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + self.check_inputs( + height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor + ) + height = block_state.height or components.default_height + width = block_state.width or components.default_width + + block_state.processed_image, block_state.processed_mask_image, block_state.mask_overlay_kwargs = ( + components.image_mask_processor.preprocess( + image=block_state.image, + mask=block_state.mask_image, + height=height, + width=width, + padding_mask_crop=block_state.padding_mask_crop, + ) + ) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageEditInpaintProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images should be + resized first. + + Components: + image_mask_processor (`InpaintProcessor`) + + Inputs: + mask_image (`Image`): + Mask image for inpainting. + resized_image (`Image`): + The resized image. should be generated using a resize step + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + + Outputs: + processed_image (`Tensor`): + The processed image + processed_mask_image (`Tensor`): + The processed mask image + mask_overlay_kwargs (`dict`): + The kwargs for the postprocess step to apply the mask overlay + """ + + model_name = "qwenimage-edit" + + @property + def description(self) -> str: + return "Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images should be resized first." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_mask_processor", + InpaintProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("mask_image"), + InputParam( + name="resized_image", + required=True, + type_hint=PIL.Image.Image, + description="The resized image. should be generated using a resize step", + ), + InputParam.template("padding_mask_crop"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam(name="processed_image", type_hint=torch.Tensor, description="The processed image"), + OutputParam( + name="processed_mask_image", + type_hint=torch.Tensor, + description="The processed mask image", + ), + OutputParam( + name="mask_overlay_kwargs", + type_hint=dict, + description="The kwargs for the postprocess step to apply the mask overlay", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + width, height = block_state.resized_image[0].size + + block_state.processed_image, block_state.processed_mask_image, block_state.mask_overlay_kwargs = ( + components.image_mask_processor.preprocess( + image=block_state.resized_image, + mask=block_state.mask_image, + height=height, + width=width, + padding_mask_crop=block_state.padding_mask_crop, + ) + ) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step. will resize the image to the given height and width. + + Components: + image_processor (`VaeImageProcessor`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + + Outputs: + processed_image (`Tensor`): + The processed image + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Image Preprocess step. will resize the image to the given height and width." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("image"), + InputParam.template("height"), + InputParam.template("width"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="processed_image", + type_hint=torch.Tensor, + description="The processed image", + ) + ] + + @staticmethod + def check_inputs(height, width, vae_scale_factor): + if height is not None and height % (vae_scale_factor * 2) != 0: + raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}") + + if width is not None and width % (vae_scale_factor * 2) != 0: + raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + self.check_inputs( + height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor + ) + height = block_state.height or components.default_height + width = block_state.width or components.default_width + + block_state.processed_image = components.image_processor.preprocess( + image=block_state.image, + height=height, + width=width, + ) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageEditProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step. Images needs to be resized first. + + Components: + image_processor (`VaeImageProcessor`) + + Inputs: + resized_image (`list`): + The resized image. should be generated using a resize step + + Outputs: + processed_image (`Tensor`): + The processed image + """ + + model_name = "qwenimage-edit" + + @property + def description(self) -> str: + return "Image Preprocess step. Images needs to be resized first." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="resized_image", + required=True, + type_hint=list[PIL.Image.Image], + description="The resized image. should be generated using a resize step", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="processed_image", + type_hint=torch.Tensor, + description="The processed image", + ) + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + width, height = block_state.resized_image[0].size + + block_state.processed_image = components.image_processor.preprocess( + image=block_state.resized_image, + height=height, + width=width, + ) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageEditPlusProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step. Images can be resized first. If a list of images is provided, will return a list of + processed images. + + Components: + image_processor (`VaeImageProcessor`) + + Inputs: + resized_image (`list`): + The resized image. should be generated using a resize step + + Outputs: + processed_image (`Tensor`): + The processed image + """ + + model_name = "qwenimage-edit-plus" + + @property + def description(self) -> str: + return "Image Preprocess step. Images can be resized first. If a list of images is provided, will return a list of processed images." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="resized_image", + required=True, + type_hint=list[PIL.Image.Image], + description="The resized image. should be generated using a resize step", + ) + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="processed_image", + type_hint=torch.Tensor, + description="The processed image", + ) + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + image = block_state.resized_image + + is_image_list = isinstance(image, list) + if not is_image_list: + image = [image] + + processed_images = [] + for img in image: + img_width, img_height = img.size + processed_images.append( + components.image_processor.preprocess(image=img, height=img_height, width=img_width) + ) + + if is_image_list: + block_state.processed_image = processed_images + else: + block_state.processed_image = processed_images[0] + + self.set_block_state(state, block_state) + return components, state + + +# ==================== +# 5. VAE ENCODER +# ==================== + + +# auto_docstring +class QwenImageVaeEncoderStep(ModularPipelineBlocks): + """ + VAE Encoder step that converts processed_image into latent representations image_latents. + Handles both single images and lists of images with varied resolutions. + + Components: + vae (`AutoencoderKLQwenImage`) + + Inputs: + processed_image (`Tensor`): + The image tensor to encode + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + image_latents (`Tensor`): + The latent representation of the input image. + """ + + model_name = "qwenimage" + + def __init__(self, input: InputParam | None = None, output: OutputParam | None = None): + """Initialize a VAE encoder step for converting images to latent representations. + + Handles both single images and lists of images. When input is a list, outputs a list of latents. When input is + a single tensor, outputs a single latent tensor. + + Args: + input (InputParam, optional): Input parameter for the processed image. Defaults to "processed_image". + output (OutputParam, optional): Output parameter for the image latents. Defaults to "image_latents". + """ + if input is None: + input = InputParam( + name="processed_image", required=True, type_hint=torch.Tensor, description="The image tensor to encode" + ) + + if output is None: + output = OutputParam.template("image_latents") + + if not isinstance(input, InputParam): + raise ValueError(f"input must be InputParam but is {type(input)}") + if not isinstance(output, OutputParam): + raise ValueError(f"output must be OutputParam but is {type(output)}") + + self._input = input + self._output = output + self._image_input_name = input.name + self._image_latents_output_name = output.name + super().__init__() + + @property + def description(self) -> str: + return ( + f"VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n" + "Handles both single images and lists of images with varied resolutions." + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("vae", AutoencoderKLQwenImage)] + + @property + def inputs(self) -> list[InputParam]: + return [ + self._input, # default is "processed_image" + InputParam.template("generator"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [self._output] # default is "image_latents" + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + dtype = components.vae.dtype + + image = getattr(block_state, self._image_input_name) + is_image_list = isinstance(image, list) + if not is_image_list: + image = [image] + + # Handle both single image and list of images + image_latents = [] + for img in image: + image_latents.append( + encode_vae_image( + image=img, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=dtype, + latent_channels=components.num_channels_latents, + ) + ) + if not is_image_list: + image_latents = image_latents[0] + + setattr(block_state, self._image_latents_output_name, image_latents) + + self.set_block_state(state, block_state) + + return components, state + + +# auto_docstring +class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks): + """ + VAE Encoder step that converts `control_image` into latent representations control_image_latents. + + Components: + vae (`AutoencoderKLQwenImage`) controlnet (`QwenImageControlNetModel`) control_image_processor + (`VaeImageProcessor`) + + Inputs: + control_image (`Image`): + Control image for ControlNet conditioning. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + control_image_latents (`Tensor`): + The latents representing the control image + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "VAE Encoder step that converts `control_image` into latent representations control_image_latents.\n" + + @property + def expected_components(self) -> list[ComponentSpec]: + components = [ + ComponentSpec("vae", AutoencoderKLQwenImage), + ComponentSpec("controlnet", QwenImageControlNetModel), + ComponentSpec( + "control_image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + return components + + @property + def inputs(self) -> list[InputParam]: + inputs = [ + InputParam.template("control_image"), + InputParam.template("height"), + InputParam.template("width"), + InputParam.template("generator"), + ] + return inputs + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "control_image_latents", + type_hint=torch.Tensor, + description="The latents representing the control image", + ) + ] + + @staticmethod + def check_inputs(height, width, vae_scale_factor): + if height is not None and height % (vae_scale_factor * 2) != 0: + raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}") + + if width is not None and width % (vae_scale_factor * 2) != 0: + raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + self.check_inputs(block_state.height, block_state.width, components.vae_scale_factor) + + device = components._execution_device + dtype = components.vae.dtype + + height = block_state.height or components.default_height + width = block_state.width or components.default_width + + controlnet = unwrap_module(components.controlnet) + if isinstance(controlnet, QwenImageMultiControlNetModel) and not isinstance(block_state.control_image, list): + block_state.control_image = [block_state.control_image] + + if isinstance(controlnet, QwenImageMultiControlNetModel): + block_state.control_image_latents = [] + for control_image_ in block_state.control_image: + control_image_ = components.control_image_processor.preprocess( + image=control_image_, + height=height, + width=width, + ) + + control_image_latents_ = encode_vae_image( + image=control_image_, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=dtype, + latent_channels=components.num_channels_latents, + sample_mode="sample", + ) + block_state.control_image_latents.append(control_image_latents_) + + elif isinstance(controlnet, QwenImageControlNetModel): + control_image = components.control_image_processor.preprocess( + image=block_state.control_image, + height=height, + width=width, + ) + block_state.control_image_latents = encode_vae_image( + image=control_image, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=dtype, + latent_channels=components.num_channels_latents, + sample_mode="sample", + ) + + else: + raise ValueError( + f"Expected controlnet to be a QwenImageControlNetModel or QwenImageMultiControlNetModel, got {type(controlnet)}" + ) + + self.set_block_state(state, block_state) + + return components, state + + +# ==================== +# 6. PERMUTE LATENTS +# ==================== + + +# auto_docstring +class QwenImageLayeredPermuteLatentsStep(ModularPipelineBlocks): + """ + Permute image latents from (B, C, 1, H, W) to (B, 1, C, H, W) for Layered packing. + + Inputs: + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + image_latents (`Tensor`): + The latent representation of the input image. (permuted from [B, C, 1, H, W] to [B, 1, C, H, W]) + """ + + model_name = "qwenimage-layered" + + @property + def description(self) -> str: + return "Permute image latents from (B, C, 1, H, W) to (B, 1, C, H, W) for Layered packing." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("image_latents"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam.template("image_latents", note="permuted from [B, C, 1, H, W] to [B, 1, C, H, W]"), + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Permute: (B, C, 1, H, W) -> (B, 1, C, H, W) + latents = block_state.image_latents + block_state.image_latents = latents.permute(0, 2, 1, 3, 4) + + self.set_block_state(state, block_state) + return components, state diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/inputs.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/inputs.py new file mode 100644 index 0000000000000000000000000000000000000000..faec7db245df1ac1ab00245628c25abbf21e15f7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/inputs.py @@ -0,0 +1,1024 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from ...models import QwenImageMultiControlNetModel +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier + + +def repeat_tensor_to_batch_size( + input_name: str, + input_tensor: torch.Tensor, + batch_size: int, + num_images_per_prompt: int = 1, +) -> torch.Tensor: + """Repeat tensor elements to match the final batch size. + + This function expands a tensor's batch dimension to match the final batch size (batch_size * num_images_per_prompt) + by repeating each element along dimension 0. + + The input tensor must have batch size 1 or batch_size. The function will: + - If batch size is 1: repeat each element (batch_size * num_images_per_prompt) times + - If batch size equals batch_size: repeat each element num_images_per_prompt times + + Args: + input_name (str): Name of the input tensor (used for error messages) + input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size. + batch_size (int): The base batch size (number of prompts) + num_images_per_prompt (int, optional): Number of images to generate per prompt. Defaults to 1. + + Returns: + torch.Tensor: The repeated tensor with final batch size (batch_size * num_images_per_prompt) + + Raises: + ValueError: If input_tensor is not a torch.Tensor or has invalid batch size + + Examples: + tensor = torch.tensor([[1, 2, 3]]) # shape: [1, 3] repeated = repeat_tensor_to_batch_size("image", tensor, + batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape: + [4, 3] + + tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: [2, 3] repeated = repeat_tensor_to_batch_size("image", + tensor, batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]) + - shape: [4, 3] + """ + # make sure input is a tensor + if not isinstance(input_tensor, torch.Tensor): + raise ValueError(f"`{input_name}` must be a tensor") + + # make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts + if input_tensor.shape[0] == 1: + repeat_by = batch_size * num_images_per_prompt + elif input_tensor.shape[0] == batch_size: + repeat_by = num_images_per_prompt + else: + raise ValueError( + f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}" + ) + + # expand the tensor to match the batch_size * num_images_per_prompt + input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0) + + return input_tensor + + +def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor: int) -> tuple[int, int]: + """Calculate image dimensions from latent tensor dimensions. + + This function converts latent space dimensions to image space dimensions by multiplying the latent height and width + by the VAE scale factor. + + Args: + latents (torch.Tensor): The latent tensor. Must have 4 or 5 dimensions. + Expected shapes: [batch, channels, height, width] or [batch, channels, frames, height, width] + vae_scale_factor (int): The scale factor used by the VAE to compress images. + Typically 8 for most VAEs (image is 8x larger than latents in each dimension) + + Returns: + tuple[int, int]: The calculated image dimensions as (height, width) + + Raises: + ValueError: If latents tensor doesn't have 4 or 5 dimensions + + """ + # make sure the latents are not packed + if latents.ndim != 4 and latents.ndim != 5: + raise ValueError(f"unpacked latents must have 4 or 5 dimensions, but got {latents.ndim}") + + latent_height, latent_width = latents.shape[-2:] + + height = latent_height * vae_scale_factor + width = latent_width * vae_scale_factor + + return height, width + + +# auto_docstring +class QwenImageTextInputsStep(ModularPipelineBlocks): + """ + Text input processing step that standardizes text embeddings for the pipeline. + This step: + 1. Determines `batch_size` and `dtype` based on `prompt_embeds` + 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt) + + This block should be placed after all encoder steps to process the text embeddings before they are used in + subsequent pipeline steps. + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + summary_section = ( + "Text input processing step that standardizes text embeddings for the pipeline.\n" + "This step:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)" + ) + + # Placement guidance + placement_section = "\n\nThis block should be placed after all encoder steps to process the text embeddings before they are used in subsequent pipeline steps." + + return summary_section + placement_section + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("num_images_per_prompt"), + InputParam.template("prompt_embeds"), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds"), + InputParam.template("negative_prompt_embeds_mask"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam(name="batch_size", type_hint=int, description="The batch size of the prompt embeddings"), + OutputParam(name="dtype", type_hint=torch.dtype, description="The data type of the prompt embeddings"), + OutputParam.template("prompt_embeds", note="batch-expanded"), + OutputParam.template("prompt_embeds_mask", note="batch-expanded"), + OutputParam.template("negative_prompt_embeds", note="batch-expanded"), + OutputParam.template("negative_prompt_embeds_mask", note="batch-expanded"), + ] + + @staticmethod + def check_inputs( + prompt_embeds, + prompt_embeds_mask, + negative_prompt_embeds, + negative_prompt_embeds_mask, + ): + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError("`negative_prompt_embeds_mask` is required when `negative_prompt_embeds` is not None") + + if negative_prompt_embeds is None and negative_prompt_embeds_mask is not None: + raise ValueError("cannot pass `negative_prompt_embeds_mask` without `negative_prompt_embeds`") + + if prompt_embeds_mask.shape[0] != prompt_embeds.shape[0]: + raise ValueError("`prompt_embeds_mask` must have the same batch size as `prompt_embeds`") + + elif negative_prompt_embeds is not None and negative_prompt_embeds.shape[0] != prompt_embeds.shape[0]: + raise ValueError("`negative_prompt_embeds` must have the same batch size as `prompt_embeds`") + + elif ( + negative_prompt_embeds_mask is not None and negative_prompt_embeds_mask.shape[0] != prompt_embeds.shape[0] + ): + raise ValueError("`negative_prompt_embeds_mask` must have the same batch size as `prompt_embeds`") + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + self.check_inputs( + prompt_embeds=block_state.prompt_embeds, + prompt_embeds_mask=block_state.prompt_embeds_mask, + negative_prompt_embeds=block_state.negative_prompt_embeds, + negative_prompt_embeds_mask=block_state.negative_prompt_embeds_mask, + ) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + block_state.prompt_embeds_mask = block_state.prompt_embeds_mask.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds_mask = block_state.prompt_embeds_mask.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len + ) + + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len + ) + + self.set_block_state(state, block_state) + + return components, state + + +# auto_docstring +class QwenImageAdditionalInputsStep(ModularPipelineBlocks): + """ + Input processing step that: + 1. For image latent inputs: Updates height/width if None, patchifies, and expands batch size + 2. For additional batch inputs: Expands batch dimensions to match final batch size + + Configured inputs: + - Image latent inputs: ['image_latents'] + + This block should be placed after the encoder steps and the text input step. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) + """ + + model_name = "qwenimage" + + def __init__( + self, + image_latent_inputs: list[InputParam] | None = None, + additional_batch_inputs: list[InputParam] | None = None, + ): + # by default, process `image_latents` + if image_latent_inputs is None: + image_latent_inputs = [InputParam.template("image_latents")] + if additional_batch_inputs is None: + additional_batch_inputs = [] + + if not isinstance(image_latent_inputs, list): + raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}") + else: + for input_param in image_latent_inputs: + if not isinstance(input_param, InputParam): + raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}") + + if not isinstance(additional_batch_inputs, list): + raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}") + else: + for input_param in additional_batch_inputs: + if not isinstance(input_param, InputParam): + raise ValueError( + f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}" + ) + + self._image_latent_inputs = image_latent_inputs + self._additional_batch_inputs = additional_batch_inputs + super().__init__() + + @property + def description(self) -> str: + summary_section = ( + "Input processing step that:\n" + " 1. For image latent inputs: Updates height/width if None, patchifies, and expands batch size\n" + " 2. For additional batch inputs: Expands batch dimensions to match final batch size" + ) + + inputs_info = "" + if self._image_latent_inputs or self._additional_batch_inputs: + inputs_info = "\n\nConfigured inputs:" + if self._image_latent_inputs: + inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}" + if self._additional_batch_inputs: + inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}" + + placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." + + return summary_section + inputs_info + placement_section + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"), + ] + + @property + def inputs(self) -> list[InputParam]: + inputs = [ + InputParam.template("num_images_per_prompt"), + InputParam.template("batch_size"), + InputParam.template("height"), + InputParam.template("width"), + ] + # default is `image_latents` + inputs += self._image_latent_inputs + self._additional_batch_inputs + + return inputs + + @property + def intermediate_outputs(self) -> list[OutputParam]: + outputs = [ + OutputParam( + name="image_height", + type_hint=int, + description="The image height calculated from the image latents dimension", + ), + OutputParam( + name="image_width", + type_hint=int, + description="The image width calculated from the image latents dimension", + ), + ] + + # `height`/`width` are not new outputs, but they will be updated if any image latent inputs are provided + if len(self._image_latent_inputs) > 0: + outputs.append( + OutputParam(name="height", type_hint=int, description="if not provided, updated to image height") + ) + outputs.append( + OutputParam(name="width", type_hint=int, description="if not provided, updated to image width") + ) + + # image latent inputs are modified in place (patchified and batch-expanded) + for input_param in self._image_latent_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (patchified and batch-expanded)", + ) + ) + + # additional batch inputs (batch-expanded only) + for input_param in self._additional_batch_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (batch-expanded)", + ) + ) + + return outputs + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Process image latent inputs + for input_param in self._image_latent_inputs: + image_latent_input_name = input_param.name + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue + + # 1. Calculate height/width from latents and update if not provided + height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor) + block_state.height = block_state.height or height + block_state.width = block_state.width or width + + if not hasattr(block_state, "image_height"): + block_state.image_height = height + if not hasattr(block_state, "image_width"): + block_state.image_width = width + + # 2. Patchify + image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor) + + # 3. Expand batch size + image_latent_tensor = repeat_tensor_to_batch_size( + input_name=image_latent_input_name, + input_tensor=image_latent_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, image_latent_input_name, image_latent_tensor) + + # Process additional batch inputs (only batch expansion) + for input_param in self._additional_batch_inputs: + input_name = input_param.name + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue + + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks): + """ + Input processing step for Edit Plus that: + 1. For image latent inputs (list): Collects heights/widths, patchifies each, concatenates, expands batch + 2. For additional batch inputs: Expands batch dimensions to match final batch size + Height/width defaults to last image in the list. + + Configured inputs: + - Image latent inputs: ['image_latents'] + + This block should be placed after the encoder steps and the text input step. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + image_height (`list`): + The image heights calculated from the image latents dimension + image_width (`list`): + The image widths calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified, + concatenated, and batch-expanded) + """ + + model_name = "qwenimage-edit-plus" + + def __init__( + self, + image_latent_inputs: list[InputParam] | None = None, + additional_batch_inputs: list[InputParam] | None = None, + ): + if image_latent_inputs is None: + image_latent_inputs = [InputParam.template("image_latents")] + if additional_batch_inputs is None: + additional_batch_inputs = [] + + if not isinstance(image_latent_inputs, list): + raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}") + else: + for input_param in image_latent_inputs: + if not isinstance(input_param, InputParam): + raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}") + + if not isinstance(additional_batch_inputs, list): + raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}") + else: + for input_param in additional_batch_inputs: + if not isinstance(input_param, InputParam): + raise ValueError( + f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}" + ) + + self._image_latent_inputs = image_latent_inputs + self._additional_batch_inputs = additional_batch_inputs + super().__init__() + + @property + def description(self) -> str: + summary_section = ( + "Input processing step for Edit Plus that:\n" + " 1. For image latent inputs (list): Collects heights/widths, patchifies each, concatenates, expands batch\n" + " 2. For additional batch inputs: Expands batch dimensions to match final batch size\n" + " Height/width defaults to last image in the list." + ) + + inputs_info = "" + if self._image_latent_inputs or self._additional_batch_inputs: + inputs_info = "\n\nConfigured inputs:" + if self._image_latent_inputs: + inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}" + if self._additional_batch_inputs: + inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}" + + placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." + + return summary_section + inputs_info + placement_section + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"), + ] + + @property + def inputs(self) -> list[InputParam]: + inputs = [ + InputParam.template("num_images_per_prompt"), + InputParam.template("batch_size"), + InputParam.template("height"), + InputParam.template("width"), + ] + + # default is `image_latents` + inputs += self._image_latent_inputs + self._additional_batch_inputs + + return inputs + + @property + def intermediate_outputs(self) -> list[OutputParam]: + outputs = [ + OutputParam( + name="image_height", + type_hint=list[int], + description="The image heights calculated from the image latents dimension", + ), + OutputParam( + name="image_width", + type_hint=list[int], + description="The image widths calculated from the image latents dimension", + ), + ] + + # `height`/`width` are updated if any image latent inputs are provided + if len(self._image_latent_inputs) > 0: + outputs.append( + OutputParam(name="height", type_hint=int, description="if not provided, updated to image height") + ) + outputs.append( + OutputParam(name="width", type_hint=int, description="if not provided, updated to image width") + ) + + # image latent inputs are modified in place (patchified, concatenated, and batch-expanded) + for input_param in self._image_latent_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (patchified, concatenated, and batch-expanded)", + ) + ) + + # additional batch inputs (batch-expanded only) + for input_param in self._additional_batch_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (batch-expanded)", + ) + ) + + return outputs + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Process image latent inputs + for input_param in self._image_latent_inputs: + image_latent_input_name = input_param.name + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue + + is_list = isinstance(image_latent_tensor, list) + if not is_list: + image_latent_tensor = [image_latent_tensor] + + image_heights = [] + image_widths = [] + packed_image_latent_tensors = [] + + for i, img_latent_tensor in enumerate(image_latent_tensor): + # 1. Calculate height/width from latents + height, width = calculate_dimension_from_latents(img_latent_tensor, components.vae_scale_factor) + image_heights.append(height) + image_widths.append(width) + + # 2. Patchify + img_latent_tensor = components.pachifier.pack_latents(img_latent_tensor) + + # 3. Expand batch size + img_latent_tensor = repeat_tensor_to_batch_size( + input_name=f"{image_latent_input_name}[{i}]", + input_tensor=img_latent_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + packed_image_latent_tensors.append(img_latent_tensor) + + # Concatenate all packed latents along dim=1 + packed_image_latent_tensors = torch.cat(packed_image_latent_tensors, dim=1) + + # Output lists of heights/widths + block_state.image_height = image_heights + block_state.image_width = image_widths + + # Default height/width from last image + block_state.height = block_state.height or image_heights[-1] + block_state.width = block_state.width or image_widths[-1] + + setattr(block_state, image_latent_input_name, packed_image_latent_tensors) + + # Process additional batch inputs (only batch expansion) + for input_param in self._additional_batch_inputs: + input_name = input_param.name + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue + + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state + + +# same as QwenImageAdditionalInputsStep, but with layered pachifier. + + +# auto_docstring +class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks): + """ + Input processing step for Layered that: + 1. For image latent inputs: Updates height/width if None, patchifies with layered pachifier, and expands batch + size + 2. For additional batch inputs: Expands batch dimensions to match final batch size + + Configured inputs: + - Image latent inputs: ['image_latents'] + + This block should be placed after the encoder steps and the text input step. + + Components: + pachifier (`QwenImageLayeredPachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified + with layered pachifier and batch-expanded) + """ + + model_name = "qwenimage-layered" + + def __init__( + self, + image_latent_inputs: list[InputParam] | None = None, + additional_batch_inputs: list[InputParam] | None = None, + ): + if image_latent_inputs is None: + image_latent_inputs = [InputParam.template("image_latents")] + if additional_batch_inputs is None: + additional_batch_inputs = [] + + if not isinstance(image_latent_inputs, list): + raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}") + else: + for input_param in image_latent_inputs: + if not isinstance(input_param, InputParam): + raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}") + + if not isinstance(additional_batch_inputs, list): + raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}") + else: + for input_param in additional_batch_inputs: + if not isinstance(input_param, InputParam): + raise ValueError( + f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}" + ) + + self._image_latent_inputs = image_latent_inputs + self._additional_batch_inputs = additional_batch_inputs + super().__init__() + + @property + def description(self) -> str: + summary_section = ( + "Input processing step for Layered that:\n" + " 1. For image latent inputs: Updates height/width if None, patchifies with layered pachifier, and expands batch size\n" + " 2. For additional batch inputs: Expands batch dimensions to match final batch size" + ) + + inputs_info = "" + if self._image_latent_inputs or self._additional_batch_inputs: + inputs_info = "\n\nConfigured inputs:" + if self._image_latent_inputs: + inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}" + if self._additional_batch_inputs: + inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}" + + placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." + + return summary_section + inputs_info + placement_section + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config"), + ] + + @property + def inputs(self) -> list[InputParam]: + inputs = [ + InputParam.template("num_images_per_prompt"), + InputParam.template("batch_size"), + ] + # default is `image_latents` + + inputs += self._image_latent_inputs + self._additional_batch_inputs + + return inputs + + @property + def intermediate_outputs(self) -> list[OutputParam]: + outputs = [ + OutputParam( + name="image_height", + type_hint=int, + description="The image height calculated from the image latents dimension", + ), + OutputParam( + name="image_width", + type_hint=int, + description="The image width calculated from the image latents dimension", + ), + ] + + if len(self._image_latent_inputs) > 0: + outputs.append( + OutputParam(name="height", type_hint=int, description="if not provided, updated to image height") + ) + outputs.append( + OutputParam(name="width", type_hint=int, description="if not provided, updated to image width") + ) + + # Add outputs for image latent inputs (patchified with layered pachifier and batch-expanded) + for input_param in self._image_latent_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (patchified with layered pachifier and batch-expanded)", + ) + ) + + # Add outputs for additional batch inputs (batch-expanded only) + for input_param in self._additional_batch_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (batch-expanded)", + ) + ) + + return outputs + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Process image latent inputs + for input_param in self._image_latent_inputs: + image_latent_input_name = input_param.name + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue + + # 1. Calculate height/width from latents and update if not provided + # Layered latents are (B, layers, C, H, W) + height = image_latent_tensor.shape[3] * components.vae_scale_factor + width = image_latent_tensor.shape[4] * components.vae_scale_factor + block_state.height = height + block_state.width = width + + if not hasattr(block_state, "image_height"): + block_state.image_height = height + if not hasattr(block_state, "image_width"): + block_state.image_width = width + + # 2. Patchify with layered pachifier + image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor) + + # 3. Expand batch size + image_latent_tensor = repeat_tensor_to_batch_size( + input_name=image_latent_input_name, + input_tensor=image_latent_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, image_latent_input_name, image_latent_tensor) + + # Process additional batch inputs (only batch expansion) + for input_param in self._additional_batch_inputs: + input_name = input_param.name + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue + + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageControlNetInputsStep(ModularPipelineBlocks): + """ + prepare the `control_image_latents` for controlnet. Insert after all the other inputs steps. + + Inputs: + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + + Outputs: + control_image_latents (`Tensor`): + The control image latents (patchified and batch-expanded). + height (`int`): + if not provided, updated to control image height + width (`int`): + if not provided, updated to control image width + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "prepare the `control_image_latents` for controlnet. Insert after all the other inputs steps." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + name="control_image_latents", + required=True, + type_hint=torch.Tensor, + description="The control image latents to use for the denoising process. Can be generated in controlnet vae encoder step.", + ), + InputParam.template("batch_size"), + InputParam.template("num_images_per_prompt"), + InputParam.template("height"), + InputParam.template("width"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + name="control_image_latents", + type_hint=torch.Tensor, + description="The control image latents (patchified and batch-expanded).", + ), + OutputParam(name="height", type_hint=int, description="if not provided, updated to control image height"), + OutputParam(name="width", type_hint=int, description="if not provided, updated to control image width"), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if isinstance(components.controlnet, QwenImageMultiControlNetModel): + control_image_latents = [] + # loop through each control_image_latents + for i, control_image_latents_ in enumerate(block_state.control_image_latents): + # 1. update height/width if not provided + height, width = calculate_dimension_from_latents(control_image_latents_, components.vae_scale_factor) + block_state.height = block_state.height or height + block_state.width = block_state.width or width + + # 2. pack + control_image_latents_ = components.pachifier.pack_latents(control_image_latents_) + + # 3. repeat to match the batch size + control_image_latents_ = repeat_tensor_to_batch_size( + input_name=f"control_image_latents[{i}]", + input_tensor=control_image_latents_, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + control_image_latents.append(control_image_latents_) + + block_state.control_image_latents = control_image_latents + + else: + # 1. update height/width if not provided + height, width = calculate_dimension_from_latents( + block_state.control_image_latents, components.vae_scale_factor + ) + block_state.height = block_state.height or height + block_state.width = block_state.width or width + + # 2. pack + block_state.control_image_latents = components.pachifier.pack_latents(block_state.control_image_latents) + + # 3. repeat to match the batch size + block_state.control_image_latents = repeat_tensor_to_batch_size( + input_name="control_image_latents", + input_tensor=block_state.control_image_latents, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + block_state.control_image_latents = block_state.control_image_latents + + self.set_block_state(state, block_state) + + return components, state diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py new file mode 100644 index 0000000000000000000000000000000000000000..bf87028b2f90841c9ba17357879b69b74cf8b150 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py @@ -0,0 +1,1224 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InputParam, InsertableDict, OutputParam +from .before_denoise import ( + QwenImageControlNetBeforeDenoiserStep, + QwenImageCreateMaskLatentsStep, + QwenImagePrepareLatentsStep, + QwenImagePrepareLatentsWithStrengthStep, + QwenImageRoPEInputsStep, + QwenImageSetTimestepsStep, + QwenImageSetTimestepsWithStrengthStep, +) +from .decoders import ( + QwenImageAfterDenoiseStep, + QwenImageDecoderStep, + QwenImageInpaintProcessImagesOutputStep, + QwenImageProcessImagesOutputStep, +) +from .denoise import ( + QwenImageControlNetDenoiseStep, + QwenImageDenoiseStep, + QwenImageInpaintControlNetDenoiseStep, + QwenImageInpaintDenoiseStep, +) +from .encoders import ( + QwenImageControlNetVaeEncoderStep, + QwenImageInpaintProcessImagesInputStep, + QwenImageProcessImagesInputStep, + QwenImageTextEncoderStep, + QwenImageVaeEncoderStep, +) +from .inputs import ( + QwenImageAdditionalInputsStep, + QwenImageControlNetInputsStep, + QwenImageTextInputsStep, +) + + +logger = logging.get_logger(__name__) + + +# ==================== +# 1. TEXT ENCODER +# ==================== + + +# auto_docstring +class QwenImageAutoTextEncoderStep(AutoPipelineBlocks): + """ + Text encoder step that encodes the text prompt into a text embedding. This is an auto pipeline block. + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`): + The tokenizer to use guider (`ClassifierFreeGuidance`) + + Inputs: + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + + Outputs: + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + + model_name = "qwenimage" + block_classes = [QwenImageTextEncoderStep()] + block_names = ["text_encoder"] + block_trigger_inputs = ["prompt"] + + @property + def description(self) -> str: + return "Text encoder step that encodes the text prompt into a text embedding. This is an auto pipeline block." + " - `QwenImageTextEncoderStep` (text_encoder) is used when `prompt` is provided." + " - if `prompt` is not provided, step will be skipped." + + +# ==================== +# 2. VAE ENCODER +# ==================== + + +# auto_docstring +class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks): + """ + This step is used for processing image and mask inputs for inpainting tasks. It: + - Resizes the image to the target size, based on `height` and `width`. + - Processes and updates `image` and `mask_image`. + - Creates `image_latents`. + + Components: + image_mask_processor (`InpaintProcessor`) vae (`AutoencoderKLQwenImage`) + + Inputs: + mask_image (`Image`): + Mask image for inpainting. + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + processed_image (`Tensor`): + The processed image + processed_mask_image (`Tensor`): + The processed mask image + mask_overlay_kwargs (`dict`): + The kwargs for the postprocess step to apply the mask overlay + image_latents (`Tensor`): + The latent representation of the input image. + """ + + model_name = "qwenimage" + block_classes = [QwenImageInpaintProcessImagesInputStep(), QwenImageVaeEncoderStep()] + block_names = ["preprocess", "encode"] + + @property + def description(self) -> str: + return ( + "This step is used for processing image and mask inputs for inpainting tasks. It:\n" + " - Resizes the image to the target size, based on `height` and `width`.\n" + " - Processes and updates `image` and `mask_image`.\n" + " - Creates `image_latents`." + ) + + +# auto_docstring +class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks): + """ + Vae encoder step that preprocess andencode the image inputs into their latent representations. + + Components: + image_processor (`VaeImageProcessor`) vae (`AutoencoderKLQwenImage`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + processed_image (`Tensor`): + The processed image + image_latents (`Tensor`): + The latent representation of the input image. + """ + + model_name = "qwenimage" + + block_classes = [QwenImageProcessImagesInputStep(), QwenImageVaeEncoderStep()] + block_names = ["preprocess", "encode"] + + @property + def description(self) -> str: + return "Vae encoder step that preprocess andencode the image inputs into their latent representations." + + +class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep] + block_names = ["inpaint", "img2img"] + block_trigger_inputs = ["mask_image", "image"] + + @property + def description(self): + return ( + "Vae encoder step that encode the image inputs into their latent representations.\n" + + "This is an auto pipeline block.\n" + + " - `QwenImageInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n" + + " - `QwenImageImg2ImgVaeEncoderStep` (img2img) is used when `image` is provided.\n" + + " - if `mask_image` or `image` is not provided, step will be skipped." + ) + + +# optional controlnet vae encoder +# auto_docstring +class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks): + """ + Vae encoder step that encode the image inputs into their latent representations. + This is an auto pipeline block. + - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided. + - if `control_image` is not provided, step will be skipped. + + Components: + vae (`AutoencoderKLQwenImage`) controlnet (`QwenImageControlNetModel`) control_image_processor + (`VaeImageProcessor`) + + Inputs: + control_image (`Image`, *optional*): + Control image for ControlNet conditioning. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + control_image_latents (`Tensor`): + The latents representing the control image + """ + + block_classes = [QwenImageControlNetVaeEncoderStep] + block_names = ["controlnet"] + block_trigger_inputs = ["control_image"] + + @property + def description(self): + return ( + "Vae encoder step that encode the image inputs into their latent representations.\n" + + "This is an auto pipeline block.\n" + + " - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided.\n" + + " - if `control_image` is not provided, step will be skipped." + ) + + +# ==================== +# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise) +# ==================== + + +# assemble input steps +# auto_docstring +class QwenImageImg2ImgInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the img2img denoising step. It: + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) + """ + + model_name = "qwenimage" + block_classes = [QwenImageTextInputsStep(), QwenImageAdditionalInputsStep()] + block_names = ["text_inputs", "additional_inputs"] + + @property + def description(self): + return "Input step that prepares the inputs for the img2img denoising step. It:\n" + " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n" + " - update height/width based `image_latents`, patchify `image_latents`." + + +# auto_docstring +class QwenImageInpaintInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the inpainting denoising step. It: + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) + processed_mask_image (`Tensor`): + The processed mask image (batch-expanded) + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageTextInputsStep(), + QwenImageAdditionalInputsStep( + additional_batch_inputs=[ + InputParam(name="processed_mask_image", type_hint=torch.Tensor, description="The processed mask image") + ] + ), + ] + block_names = ["text_inputs", "additional_inputs"] + + @property + def description(self): + return "Input step that prepares the inputs for the inpainting denoising step. It:\n" + " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents` and `processed_mask_image`).\n" + " - update height/width based `image_latents`, patchify `image_latents`." + + +# assemble prepare latents steps +# auto_docstring +class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks): + """ + This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It: + - Add noise to the image latents to create the latents input for the denoiser. + - Create the pachified latents `mask` based on the processedmask image. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) pachifier (`QwenImagePachifier`) + + Inputs: + latents (`Tensor`): + The initial random noised, can be generated in prepare latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (Can be + generated from vae encoder and updated in input step.) + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + processed_mask_image (`Tensor`): + The processed mask to use for the inpainting process. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. + + Outputs: + initial_noise (`Tensor`): + The initial random noised used for inpainting denoising. + latents (`Tensor`): + The scaled noisy latents to use for inpainting/image-to-image denoising. + mask (`Tensor`): + The mask to use for the inpainting process. + """ + + model_name = "qwenimage" + block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()] + block_names = ["add_noise_to_latents", "create_mask_latents"] + + @property + def description(self) -> str: + return ( + "This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:\n" + " - Add noise to the image latents to create the latents input for the denoiser.\n" + " - Create the pachified latents `mask` based on the processedmask image.\n" + ) + + +# assemble denoising steps + + +# Qwen Image (text2image) +# auto_docstring +class QwenImageCoreDenoiseStep(SequentialPipelineBlocks): + """ + step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs + (timesteps, latents, rope inputs etc.). + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageTextInputsStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsStep(), + QwenImageRoPEInputsStep(), + QwenImageDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "prepare_latents", + "set_timesteps", + "prepare_rope_inputs", + "denoise", + "after_denoise", + ] + + @property + def description(self): + return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Qwen Image (inpainting) +# auto_docstring +class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint + task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageInpaintInputStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsWithStrengthStep(), + QwenImageInpaintPrepareLatentsStep(), + QwenImageRoPEInputsStep(), + QwenImageInpaintDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "prepare_latents", + "set_timesteps", + "prepare_inpaint_latents", + "prepare_rope_inputs", + "denoise", + "after_denoise", + ] + + @property + def description(self): + return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Qwen Image (image2image) +# auto_docstring +class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img + task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageImg2ImgInputStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsWithStrengthStep(), + QwenImagePrepareLatentsWithStrengthStep(), + QwenImageRoPEInputsStep(), + QwenImageDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "prepare_latents", + "set_timesteps", + "prepare_img2img_latents", + "prepare_rope_inputs", + "denoise", + "after_denoise", + ] + + @property + def description(self): + return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Qwen Image (text2image) with controlnet +# auto_docstring +class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks): + """ + step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs + (timesteps, latents, rope inputs etc.). + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) controlnet + (`QwenImageControlNetModel`) guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageTextInputsStep(), + QwenImageControlNetInputsStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsStep(), + QwenImageRoPEInputsStep(), + QwenImageControlNetBeforeDenoiserStep(), + QwenImageControlNetDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "controlnet_input", + "prepare_latents", + "set_timesteps", + "prepare_rope_inputs", + "controlnet_before_denoise", + "controlnet_denoise", + "after_denoise", + ] + + @property + def description(self): + return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Qwen Image (inpainting) with controlnet +# auto_docstring +class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint + task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) controlnet + (`QwenImageControlNetModel`) guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageInpaintInputStep(), + QwenImageControlNetInputsStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsWithStrengthStep(), + QwenImageInpaintPrepareLatentsStep(), + QwenImageRoPEInputsStep(), + QwenImageControlNetBeforeDenoiserStep(), + QwenImageInpaintControlNetDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "controlnet_input", + "prepare_latents", + "set_timesteps", + "prepare_inpaint_latents", + "prepare_rope_inputs", + "controlnet_before_denoise", + "controlnet_denoise", + "after_denoise", + ] + + @property + def description(self): + return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Qwen Image (image2image) with controlnet +# auto_docstring +class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img + task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) controlnet + (`QwenImageControlNetModel`) guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageImg2ImgInputStep(), + QwenImageControlNetInputsStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsWithStrengthStep(), + QwenImagePrepareLatentsWithStrengthStep(), + QwenImageRoPEInputsStep(), + QwenImageControlNetBeforeDenoiserStep(), + QwenImageControlNetDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "controlnet_input", + "prepare_latents", + "set_timesteps", + "prepare_img2img_latents", + "prepare_rope_inputs", + "controlnet_before_denoise", + "controlnet_denoise", + "after_denoise", + ] + + @property + def description(self): + return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Auto denoise step for QwenImage +class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks): + block_classes = [ + QwenImageCoreDenoiseStep, + QwenImageInpaintCoreDenoiseStep, + QwenImageImg2ImgCoreDenoiseStep, + QwenImageControlNetCoreDenoiseStep, + QwenImageControlNetInpaintCoreDenoiseStep, + QwenImageControlNetImg2ImgCoreDenoiseStep, + ] + block_names = [ + "text2image", + "inpaint", + "img2img", + "controlnet_text2image", + "controlnet_inpaint", + "controlnet_img2img", + ] + block_trigger_inputs = ["control_image_latents", "processed_mask_image", "image_latents"] + default_block_name = "text2image" + + def select_block(self, control_image_latents=None, processed_mask_image=None, image_latents=None): + if control_image_latents is not None: + if processed_mask_image is not None: + return "controlnet_inpaint" + elif image_latents is not None: + return "controlnet_img2img" + else: + return "controlnet_text2image" + else: + if processed_mask_image is not None: + return "inpaint" + elif image_latents is not None: + return "img2img" + else: + return "text2image" + + @property + def description(self): + return ( + "Core step that performs the denoising process. \n" + + " - `QwenImageCoreDenoiseStep` (text2image) for text2image tasks.\n" + + " - `QwenImageInpaintCoreDenoiseStep` (inpaint) for inpaint tasks.\n" + + " - `QwenImageImg2ImgCoreDenoiseStep` (img2img) for img2img tasks.\n" + + " - `QwenImageControlNetCoreDenoiseStep` (controlnet_text2image) for text2image tasks with controlnet.\n" + + " - `QwenImageControlNetInpaintCoreDenoiseStep` (controlnet_inpaint) for inpaint tasks with controlnet.\n" + + " - `QwenImageControlNetImg2ImgCoreDenoiseStep` (controlnet_img2img) for img2img tasks with controlnet.\n" + + "This step support text-to-image, image-to-image, inpainting, and controlnet tasks for QwenImage:\n" + + " - for image-to-image generation, you need to provide `image_latents`\n" + + " - for inpainting, you need to provide `processed_mask_image` and `image_latents`\n" + + " - to run the controlnet workflow, you need to provide `control_image_latents`\n" + + " - for text-to-image generation, all you need to provide is prompt embeddings" + ) + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# ==================== +# 4. DECODE +# ==================== + + +# standard decode step works for most tasks except for inpaint +# auto_docstring +class QwenImageDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocess the generated image. + + Components: + vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`list`): + Generated images. (tensor output of the vae decoder.) + """ + + model_name = "qwenimage" + block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()] + block_names = ["decode", "postprocess"] + + @property + def description(self): + return "Decode step that decodes the latents to images and postprocess the generated image." + + +# Inpaint decode step +# auto_docstring +class QwenImageInpaintDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocess the generated image, optional apply the mask + overally to the original image. + + Components: + vae (`AutoencoderKLQwenImage`) image_mask_processor (`InpaintProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + mask_overlay_kwargs (`dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. generated in + InpaintProcessImagesInputStep. + + Outputs: + images (`list`): + Generated images. (tensor output of the vae decoder.) + """ + + model_name = "qwenimage" + block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()] + block_names = ["decode", "postprocess"] + + @property + def description(self): + return "Decode step that decodes the latents to images and postprocess the generated image, optional apply the mask overally to the original image." + + +# Auto decode step for QwenImage +class QwenImageAutoDecodeStep(AutoPipelineBlocks): + block_classes = [QwenImageInpaintDecodeStep, QwenImageDecodeStep] + block_names = ["inpaint_decode", "decode"] + block_trigger_inputs = ["mask", None] + + @property + def description(self): + return ( + "Decode step that decode the latents into images. \n" + " This is an auto pipeline block that works for inpaint/text2image/img2img tasks, for both QwenImage and QwenImage-Edit.\n" + + " - `QwenImageInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n" + + " - `QwenImageDecodeStep` (text2image/img2img) is used when `mask` is not provided.\n" + ) + + +# ==================== +# 5. AUTO BLOCKS & PRESETS +# ==================== +AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", QwenImageAutoTextEncoderStep()), + ("vae_encoder", QwenImageAutoVaeEncoderStep()), + ("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()), + ("denoise", QwenImageAutoCoreDenoiseStep()), + ("decode", QwenImageAutoDecodeStep()), + ] +) + + +# auto_docstring +class QwenImageAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage. + + Supported workflows: + - `text2image`: requires `prompt` + - `image2image`: requires `prompt`, `image` + - `inpainting`: requires `prompt`, `mask_image`, `image` + - `controlnet_text2image`: requires `prompt`, `control_image` + - `controlnet_image2image`: requires `prompt`, `image`, `control_image` + - `controlnet_inpainting`: requires `prompt`, `mask_image`, `image`, `control_image` + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`): + The tokenizer to use guider (`ClassifierFreeGuidance`) image_mask_processor (`InpaintProcessor`) vae + (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) controlnet (`QwenImageControlNetModel`) + control_image_processor (`VaeImageProcessor`) pachifier (`QwenImagePachifier`) scheduler + (`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + mask_image (`Image`, *optional*): + Mask image for inpainting. + image (`Image | list`, *optional*): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + control_image (`Image`, *optional*): + Control image for ControlNet conditioning. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + latents (`Tensor`): + Pre-generated noisy latents for image generation. + num_inference_steps (`int`): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + control_image_latents (`Tensor`, *optional*): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + mask_overlay_kwargs (`dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. generated in + InpaintProcessImagesInputStep. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "qwenimage" + + block_classes = AUTO_BLOCKS.values() + block_names = AUTO_BLOCKS.keys() + + # Workflow map defines the trigger conditions for each workflow. + # How to define: + # - Only include required inputs and trigger inputs (inputs that determine which blocks run) + # - currently, only supports `True` means the workflow triggers when the input is not None + + _workflow_map = { + "text2image": {"prompt": True}, + "image2image": {"prompt": True, "image": True}, + "inpainting": {"prompt": True, "mask_image": True, "image": True}, + "controlnet_text2image": {"prompt": True, "control_image": True}, + "controlnet_image2image": {"prompt": True, "image": True, "control_image": True}, + "controlnet_inpainting": {"prompt": True, "mask_image": True, "image": True, "control_image": True}, + } + + @property + def description(self): + return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage." + + @property + def outputs(self): + return [OutputParam.template("images")] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py new file mode 100644 index 0000000000000000000000000000000000000000..37b80b69ec7ec5f8d027f59bfabafbda07772dc4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py @@ -0,0 +1,796 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InputParam, InsertableDict, OutputParam +from .before_denoise import ( + QwenImageCreateMaskLatentsStep, + QwenImageEditRoPEInputsStep, + QwenImagePrepareLatentsStep, + QwenImagePrepareLatentsWithStrengthStep, + QwenImageSetTimestepsStep, + QwenImageSetTimestepsWithStrengthStep, +) +from .decoders import ( + QwenImageAfterDenoiseStep, + QwenImageDecoderStep, + QwenImageInpaintProcessImagesOutputStep, + QwenImageProcessImagesOutputStep, +) +from .denoise import ( + QwenImageEditDenoiseStep, + QwenImageEditInpaintDenoiseStep, +) +from .encoders import ( + QwenImageEditInpaintProcessImagesInputStep, + QwenImageEditProcessImagesInputStep, + QwenImageEditResizeStep, + QwenImageEditTextEncoderStep, + QwenImageVaeEncoderStep, +) +from .inputs import ( + QwenImageAdditionalInputsStep, + QwenImageTextInputsStep, +) + + +logger = logging.get_logger(__name__) + + +# ==================== +# 1. TEXT ENCODER +# ==================== + + +# auto_docstring +class QwenImageEditVLEncoderStep(SequentialPipelineBlocks): + """ + QwenImage-Edit VL encoder step that encode the image and text prompts together. + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + + Outputs: + resized_image (`list`): + The resized images + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageEditResizeStep(), + QwenImageEditTextEncoderStep(), + ] + block_names = ["resize", "encode"] + + @property + def description(self) -> str: + return "QwenImage-Edit VL encoder step that encode the image and text prompts together." + + +# ==================== +# 2. VAE ENCODER +# ==================== + + +# Edit VAE encoder +# auto_docstring +class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks): + """ + Vae encoder step that encode the image inputs into their latent representations. + + Components: + image_resize_processor (`VaeImageProcessor`) image_processor (`VaeImageProcessor`) vae + (`AutoencoderKLQwenImage`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + resized_image (`list`): + The resized images + processed_image (`Tensor`): + The processed image + image_latents (`Tensor`): + The latent representation of the input image. + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageEditResizeStep(), + QwenImageEditProcessImagesInputStep(), + QwenImageVaeEncoderStep(), + ] + block_names = ["resize", "preprocess", "encode"] + + @property + def description(self) -> str: + return "Vae encoder step that encode the image inputs into their latent representations." + + +# Edit Inpaint VAE encoder +# auto_docstring +class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks): + """ + This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It: + - resize the image for target area (1024 * 1024) while maintaining the aspect ratio. + - process the resized image and mask image. + - create image latents. + + Components: + image_resize_processor (`VaeImageProcessor`) image_mask_processor (`InpaintProcessor`) vae + (`AutoencoderKLQwenImage`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + mask_image (`Image`): + Mask image for inpainting. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + resized_image (`list`): + The resized images + processed_image (`Tensor`): + The processed image + processed_mask_image (`Tensor`): + The processed mask image + mask_overlay_kwargs (`dict`): + The kwargs for the postprocess step to apply the mask overlay + image_latents (`Tensor`): + The latent representation of the input image. + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageEditResizeStep(), + QwenImageEditInpaintProcessImagesInputStep(), + QwenImageVaeEncoderStep(), + ] + block_names = ["resize", "preprocess", "encode"] + + @property + def description(self) -> str: + return ( + "This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It:\n" + " - resize the image for target area (1024 * 1024) while maintaining the aspect ratio.\n" + " - process the resized image and mask image.\n" + " - create image latents." + ) + + +# Auto VAE encoder +class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [QwenImageEditInpaintVaeEncoderStep, QwenImageEditVaeEncoderStep] + block_names = ["edit_inpaint", "edit"] + block_trigger_inputs = ["mask_image", "image"] + + @property + def description(self): + return ( + "Vae encoder step that encode the image inputs into their latent representations.\n" + "This is an auto pipeline block.\n" + " - `QwenImageEditInpaintVaeEncoderStep` (edit_inpaint) is used when `mask_image` is provided.\n" + " - `QwenImageEditVaeEncoderStep` (edit) is used when `image` is provided.\n" + " - if `mask_image` or `image` is not provided, step will be skipped." + ) + + +# ==================== +# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise) +# ==================== + + +# assemble input steps +# auto_docstring +class QwenImageEditInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the edit denoising step. It: + - make sure the text embeddings have consistent batch size as well as the additional inputs. + - update height/width based `image_latents`, patchify `image_latents`. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageTextInputsStep(), + QwenImageAdditionalInputsStep(), + ] + block_names = ["text_inputs", "additional_inputs"] + + @property + def description(self): + return ( + "Input step that prepares the inputs for the edit denoising step. It:\n" + " - make sure the text embeddings have consistent batch size as well as the additional inputs.\n" + " - update height/width based `image_latents`, patchify `image_latents`." + ) + + +# auto_docstring +class QwenImageEditInpaintInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the edit inpaint denoising step. It: + - make sure the text embeddings have consistent batch size as well as the additional inputs. + - update height/width based `image_latents`, patchify `image_latents`. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) + processed_mask_image (`Tensor`): + The processed mask image (batch-expanded) + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageTextInputsStep(), + QwenImageAdditionalInputsStep( + additional_batch_inputs=[ + InputParam(name="processed_mask_image", type_hint=torch.Tensor, description="The processed mask image") + ] + ), + ] + block_names = ["text_inputs", "additional_inputs"] + + @property + def description(self): + return ( + "Input step that prepares the inputs for the edit inpaint denoising step. It:\n" + " - make sure the text embeddings have consistent batch size as well as the additional inputs.\n" + " - update height/width based `image_latents`, patchify `image_latents`." + ) + + +# assemble prepare latents steps +# auto_docstring +class QwenImageEditInpaintPrepareLatentsStep(SequentialPipelineBlocks): + """ + This step prepares the latents/image_latents and mask inputs for the edit inpainting denoising step. It: + - Add noise to the image latents to create the latents input for the denoiser. + - Create the patchified latents `mask` based on the processed mask image. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) pachifier (`QwenImagePachifier`) + + Inputs: + latents (`Tensor`): + The initial random noised, can be generated in prepare latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (Can be + generated from vae encoder and updated in input step.) + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + processed_mask_image (`Tensor`): + The processed mask to use for the inpainting process. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. + + Outputs: + initial_noise (`Tensor`): + The initial random noised used for inpainting denoising. + latents (`Tensor`): + The scaled noisy latents to use for inpainting/image-to-image denoising. + mask (`Tensor`): + The mask to use for the inpainting process. + """ + + model_name = "qwenimage-edit" + block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()] + block_names = ["add_noise_to_latents", "create_mask_latents"] + + @property + def description(self) -> str: + return ( + "This step prepares the latents/image_latents and mask inputs for the edit inpainting denoising step. It:\n" + " - Add noise to the image latents to create the latents input for the denoiser.\n" + " - Create the patchified latents `mask` based on the processed mask image.\n" + ) + + +# Qwen Image Edit (image2image) core denoise step +# auto_docstring +class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoising workflow for QwenImage-Edit edit (img2img) task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageEditInputStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsStep(), + QwenImageEditRoPEInputsStep(), + QwenImageEditDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "prepare_latents", + "set_timesteps", + "prepare_rope_inputs", + "denoise", + "after_denoise", + ] + + @property + def description(self): + return "Core denoising workflow for QwenImage-Edit edit (img2img) task." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Qwen Image Edit (inpainting) core denoise step +# auto_docstring +class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoising workflow for QwenImage-Edit edit inpaint task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageEditInpaintInputStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsWithStrengthStep(), + QwenImageEditInpaintPrepareLatentsStep(), + QwenImageEditRoPEInputsStep(), + QwenImageEditInpaintDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "prepare_latents", + "set_timesteps", + "prepare_inpaint_latents", + "prepare_rope_inputs", + "denoise", + "after_denoise", + ] + + @property + def description(self): + return "Core denoising workflow for QwenImage-Edit edit inpaint task." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Auto core denoise step for QwenImage Edit +class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks): + model_name = "qwenimage-edit" + block_classes = [ + QwenImageEditInpaintCoreDenoiseStep, + QwenImageEditCoreDenoiseStep, + ] + block_names = ["edit_inpaint", "edit"] + block_trigger_inputs = ["processed_mask_image", "image_latents"] + default_block_name = "edit" + + def select_block(self, processed_mask_image=None, image_latents=None) -> str | None: + if processed_mask_image is not None: + return "edit_inpaint" + elif image_latents is not None: + return "edit" + return None + + @property + def description(self): + return ( + "Auto core denoising step that selects the appropriate workflow based on inputs.\n" + " - `QwenImageEditInpaintCoreDenoiseStep` when `processed_mask_image` is provided\n" + " - `QwenImageEditCoreDenoiseStep` when `image_latents` is provided\n" + "Supports edit (img2img) and edit inpainting tasks for QwenImage-Edit." + ) + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# ==================== +# 4. DECODE +# ==================== + + +# Decode step (standard) +# auto_docstring +class QwenImageEditDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocess the generated image. + + Components: + vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`list`): + Generated images. (tensor output of the vae decoder.) + """ + + model_name = "qwenimage-edit" + block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()] + block_names = ["decode", "postprocess"] + + @property + def description(self): + return "Decode step that decodes the latents to images and postprocess the generated image." + + +# Inpaint decode step +# auto_docstring +class QwenImageEditInpaintDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocess the generated image, optionally apply the mask + overlay to the original image. + + Components: + vae (`AutoencoderKLQwenImage`) image_mask_processor (`InpaintProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + mask_overlay_kwargs (`dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. generated in + InpaintProcessImagesInputStep. + + Outputs: + images (`list`): + Generated images. (tensor output of the vae decoder.) + """ + + model_name = "qwenimage-edit" + block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()] + block_names = ["decode", "postprocess"] + + @property + def description(self): + return "Decode step that decodes the latents to images and postprocess the generated image, optionally apply the mask overlay to the original image." + + +# Auto decode step +class QwenImageEditAutoDecodeStep(AutoPipelineBlocks): + block_classes = [QwenImageEditInpaintDecodeStep, QwenImageEditDecodeStep] + block_names = ["inpaint_decode", "decode"] + block_trigger_inputs = ["mask", None] + + @property + def description(self): + return ( + "Decode step that decode the latents into images.\n" + "This is an auto pipeline block.\n" + " - `QwenImageEditInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n" + " - `QwenImageEditDecodeStep` (edit) is used when `mask` is not provided.\n" + ) + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# ==================== +# 5. AUTO BLOCKS & PRESETS +# ==================== + +EDIT_AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", QwenImageEditVLEncoderStep()), + ("vae_encoder", QwenImageEditAutoVaeEncoderStep()), + ("denoise", QwenImageEditAutoCoreDenoiseStep()), + ("decode", QwenImageEditAutoDecodeStep()), + ] +) + + +# auto_docstring +class QwenImageEditAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit. + - for edit (img2img) generation, you need to provide `image` + - for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide + `padding_mask_crop` + + + Supported workflows: + - `image_conditioned`: requires `prompt`, `image` + - `image_conditioned_inpainting`: requires `prompt`, `mask_image`, `image` + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) image_mask_processor (`InpaintProcessor`) vae + (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) pachifier (`QwenImagePachifier`) scheduler + (`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + mask_image (`Image`, *optional*): + Mask image for inpainting. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + latents (`Tensor`): + Pre-generated noisy latents for image generation. + num_inference_steps (`int`): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + mask_overlay_kwargs (`dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. generated in + InpaintProcessImagesInputStep. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "qwenimage-edit" + block_classes = EDIT_AUTO_BLOCKS.values() + block_names = EDIT_AUTO_BLOCKS.keys() + _workflow_map = { + "image_conditioned": {"prompt": True, "image": True}, + "image_conditioned_inpainting": {"prompt": True, "mask_image": True, "image": True}, + } + + @property + def description(self): + return ( + "Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.\n" + "- for edit (img2img) generation, you need to provide `image`\n" + "- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`\n" + ) + + @property + def outputs(self): + return [OutputParam.template("images")] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py new file mode 100644 index 0000000000000000000000000000000000000000..4a1f418d7b4508c2b5702c7d04c3767174287ab4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py @@ -0,0 +1,407 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict, OutputParam +from .before_denoise import ( + QwenImageEditPlusRoPEInputsStep, + QwenImagePrepareLatentsStep, + QwenImageSetTimestepsStep, +) +from .decoders import ( + QwenImageAfterDenoiseStep, + QwenImageDecoderStep, + QwenImageProcessImagesOutputStep, +) +from .denoise import ( + QwenImageEditDenoiseStep, +) +from .encoders import ( + QwenImageEditPlusProcessImagesInputStep, + QwenImageEditPlusResizeStep, + QwenImageEditPlusTextEncoderStep, + QwenImageVaeEncoderStep, +) +from .inputs import ( + QwenImageEditPlusAdditionalInputsStep, + QwenImageTextInputsStep, +) + + +logger = logging.get_logger(__name__) + + +# ==================== +# 1. TEXT ENCODER +# ==================== + + +# auto_docstring +class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks): + """ + QwenImage-Edit Plus VL encoder step that encodes the image and text prompts together. + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + + Outputs: + resized_image (`list`): + Images resized to 1024x1024 target area for VAE encoding + resized_cond_image (`list`): + Images resized to 384x384 target area for VL text encoding + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + + model_name = "qwenimage-edit-plus" + block_classes = [ + QwenImageEditPlusResizeStep(), + QwenImageEditPlusTextEncoderStep(), + ] + block_names = ["resize", "encode"] + + @property + def description(self) -> str: + return "QwenImage-Edit Plus VL encoder step that encodes the image and text prompts together." + + +# ==================== +# 2. VAE ENCODER +# ==================== + + +# auto_docstring +class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks): + """ + VAE encoder step that encodes image inputs into latent representations. + Each image is resized independently based on its own aspect ratio to 1024x1024 target area. + + Components: + image_resize_processor (`VaeImageProcessor`) image_processor (`VaeImageProcessor`) vae + (`AutoencoderKLQwenImage`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + resized_image (`list`): + Images resized to 1024x1024 target area for VAE encoding + resized_cond_image (`list`): + Images resized to 384x384 target area for VL text encoding + processed_image (`Tensor`): + The processed image + image_latents (`Tensor`): + The latent representation of the input image. + """ + + model_name = "qwenimage-edit-plus" + block_classes = [ + QwenImageEditPlusResizeStep(), + QwenImageEditPlusProcessImagesInputStep(), + QwenImageVaeEncoderStep(), + ] + block_names = ["resize", "preprocess", "encode"] + + @property + def description(self) -> str: + return ( + "VAE encoder step that encodes image inputs into latent representations.\n" + "Each image is resized independently based on its own aspect ratio to 1024x1024 target area." + ) + + +# ==================== +# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise) +# ==================== + + +# assemble input steps +# auto_docstring +class QwenImageEditPlusInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the Edit Plus denoising step. It: + - Standardizes text embeddings batch size. + - Processes list of image latents: patchifies, concatenates along dim=1, expands batch. + - Outputs lists of image_height/image_width for RoPE calculation. + - Defaults height/width from last image in the list. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`list`): + The image heights calculated from the image latents dimension + image_width (`list`): + The image widths calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified, + concatenated, and batch-expanded) + """ + + model_name = "qwenimage-edit-plus" + block_classes = [ + QwenImageTextInputsStep(), + QwenImageEditPlusAdditionalInputsStep(), + ] + block_names = ["text_inputs", "additional_inputs"] + + @property + def description(self): + return ( + "Input step that prepares the inputs for the Edit Plus denoising step. It:\n" + " - Standardizes text embeddings batch size.\n" + " - Processes list of image latents: patchifies, concatenates along dim=1, expands batch.\n" + " - Outputs lists of image_height/image_width for RoPE calculation.\n" + " - Defaults height/width from last image in the list." + ) + + +# Qwen Image Edit Plus (image2image) core denoise step +# auto_docstring +class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoising workflow for QwenImage-Edit Plus edit (img2img) task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage-edit-plus" + block_classes = [ + QwenImageEditPlusInputStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsStep(), + QwenImageEditPlusRoPEInputsStep(), + QwenImageEditDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "prepare_latents", + "set_timesteps", + "prepare_rope_inputs", + "denoise", + "after_denoise", + ] + + @property + def description(self): + return "Core denoising workflow for QwenImage-Edit Plus edit (img2img) task." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# ==================== +# 4. DECODE +# ==================== + + +# auto_docstring +class QwenImageEditPlusDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocesses the generated image. + + Components: + vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`list`): + Generated images. (tensor output of the vae decoder.) + """ + + model_name = "qwenimage-edit-plus" + block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()] + block_names = ["decode", "postprocess"] + + @property + def description(self): + return "Decode step that decodes the latents to images and postprocesses the generated image." + + +# ==================== +# 5. AUTO BLOCKS & PRESETS +# ==================== + +EDIT_PLUS_AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", QwenImageEditPlusVLEncoderStep()), + ("vae_encoder", QwenImageEditPlusVaeEncoderStep()), + ("denoise", QwenImageEditPlusCoreDenoiseStep()), + ("decode", QwenImageEditPlusDecodeStep()), + ] +) + + +# auto_docstring +class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for edit (img2img) tasks using QwenImage-Edit Plus. + - `image` is required input (can be single image or list of images). + - Each image is resized independently based on its own aspect ratio. + - VL encoder uses 384x384 target area, VAE encoder uses 1024x1024 target area. + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) image_processor (`VaeImageProcessor`) vae + (`AutoencoderKLQwenImage`) pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) + transformer (`QwenImageTransformer2DModel`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "qwenimage-edit-plus" + block_classes = EDIT_PLUS_AUTO_BLOCKS.values() + block_names = EDIT_PLUS_AUTO_BLOCKS.keys() + + @property + def description(self): + return ( + "Auto Modular pipeline for edit (img2img) tasks using QwenImage-Edit Plus.\n" + "- `image` is required input (can be single image or list of images).\n" + "- Each image is resized independently based on its own aspect ratio.\n" + "- VL encoder uses 384x384 target area, VAE encoder uses 1024x1024 target area." + ) + + @property + def outputs(self): + return [OutputParam.template("images")] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py new file mode 100644 index 0000000000000000000000000000000000000000..a10454f1fb0c9ec20a2624ce55c933efdf069214 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py @@ -0,0 +1,366 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict, OutputParam +from .before_denoise import ( + QwenImageLayeredPrepareLatentsStep, + QwenImageLayeredRoPEInputsStep, + QwenImageLayeredSetTimestepsStep, +) +from .decoders import ( + QwenImageLayeredAfterDenoiseStep, + QwenImageLayeredDecoderStep, +) +from .denoise import ( + QwenImageLayeredDenoiseStep, +) +from .encoders import ( + QwenImageEditProcessImagesInputStep, + QwenImageLayeredGetImagePromptStep, + QwenImageLayeredPermuteLatentsStep, + QwenImageLayeredResizeStep, + QwenImageTextEncoderStep, + QwenImageVaeEncoderStep, +) +from .inputs import ( + QwenImageLayeredAdditionalInputsStep, + QwenImageTextInputsStep, +) + + +logger = logging.get_logger(__name__) + + +# ==================== +# 1. TEXT ENCODER +# ==================== + + +# auto_docstring +class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks): + """ + QwenImage-Layered Text encoder step that encode the text prompt, will generate a prompt based on image if not + provided. + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) tokenizer (`Qwen2Tokenizer`): The tokenizer to use guider (`ClassifierFreeGuidance`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + resolution (`int`, *optional*, defaults to 640): + The target area to resize the image to, can be 1024 or 640 + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + use_en_prompt (`bool`, *optional*, defaults to False): + Whether to use English prompt template + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + + Outputs: + resized_image (`list`): + The resized images + prompt (`str`): + The prompt or prompts to guide image generation. If not provided, updated using image caption + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + + model_name = "qwenimage-layered" + block_classes = [ + QwenImageLayeredResizeStep(), + QwenImageLayeredGetImagePromptStep(), + QwenImageTextEncoderStep(), + ] + block_names = ["resize", "get_image_prompt", "encode"] + + @property + def description(self) -> str: + return "QwenImage-Layered Text encoder step that encode the text prompt, will generate a prompt based on image if not provided." + + +# ==================== +# 2. VAE ENCODER +# ==================== + + +# Edit VAE encoder +# auto_docstring +class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks): + """ + Vae encoder step that encode the image inputs into their latent representations. + + Components: + image_resize_processor (`VaeImageProcessor`) image_processor (`VaeImageProcessor`) vae + (`AutoencoderKLQwenImage`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + resolution (`int`, *optional*, defaults to 640): + The target area to resize the image to, can be 1024 or 640 + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + resized_image (`list`): + The resized images + processed_image (`Tensor`): + The processed image + image_latents (`Tensor`): + The latent representation of the input image. + """ + + model_name = "qwenimage-layered" + block_classes = [ + QwenImageLayeredResizeStep(), + QwenImageEditProcessImagesInputStep(), + QwenImageVaeEncoderStep(), + QwenImageLayeredPermuteLatentsStep(), + ] + block_names = ["resize", "preprocess", "encode", "permute"] + + @property + def description(self) -> str: + return "Vae encoder step that encode the image inputs into their latent representations." + + +# ==================== +# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise) +# ==================== + + +# assemble input steps +# auto_docstring +class QwenImageLayeredInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the layered denoising step. It: + - make sure the text embeddings have consistent batch size as well as the additional inputs. + - update height/width based `image_latents`, patchify `image_latents`. + + Components: + pachifier (`QwenImageLayeredPachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified + with layered pachifier and batch-expanded) + """ + + model_name = "qwenimage-layered" + block_classes = [ + QwenImageTextInputsStep(), + QwenImageLayeredAdditionalInputsStep(), + ] + block_names = ["text_inputs", "additional_inputs"] + + @property + def description(self): + return ( + "Input step that prepares the inputs for the layered denoising step. It:\n" + " - make sure the text embeddings have consistent batch size as well as the additional inputs.\n" + " - update height/width based `image_latents`, patchify `image_latents`." + ) + + +# Qwen Image Layered (image2image) core denoise step +# auto_docstring +class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoising workflow for QwenImage-Layered img2img task. + + Components: + pachifier (`QwenImageLayeredPachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage-layered" + block_classes = [ + QwenImageLayeredInputStep(), + QwenImageLayeredPrepareLatentsStep(), + QwenImageLayeredSetTimestepsStep(), + QwenImageLayeredRoPEInputsStep(), + QwenImageLayeredDenoiseStep(), + QwenImageLayeredAfterDenoiseStep(), + ] + block_names = [ + "input", + "prepare_latents", + "set_timesteps", + "prepare_rope_inputs", + "denoise", + "after_denoise", + ] + + @property + def description(self): + return "Core denoising workflow for QwenImage-Layered img2img task." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# ==================== +# 4. AUTO BLOCKS & PRESETS +# ==================== + +LAYERED_AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", QwenImageLayeredTextEncoderStep()), + ("vae_encoder", QwenImageLayeredVaeEncoderStep()), + ("denoise", QwenImageLayeredCoreDenoiseStep()), + ("decode", QwenImageLayeredDecoderStep()), + ] +) + + +# auto_docstring +class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for layered denoising tasks using QwenImage-Layered. + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) tokenizer (`Qwen2Tokenizer`): The tokenizer to use guider (`ClassifierFreeGuidance`) + image_processor (`VaeImageProcessor`) vae (`AutoencoderKLQwenImage`) pachifier (`QwenImageLayeredPachifier`) + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + resolution (`int`, *optional*, defaults to 640): + The target area to resize the image to, can be 1024 or 640 + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + use_en_prompt (`bool`, *optional*, defaults to False): + Whether to use English prompt template + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "qwenimage-layered" + block_classes = LAYERED_AUTO_BLOCKS.values() + block_names = LAYERED_AUTO_BLOCKS.keys() + + @property + def description(self): + return "Auto Modular pipeline for layered denoising tasks using QwenImage-Layered." + + @property + def outputs(self): + return [OutputParam.template("images")] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/modular_pipeline.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/modular_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..892435989d00fc07526f3d2c5eec6efc4380a2fa --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/modular_pipeline.py @@ -0,0 +1,297 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import QwenImageLoraLoaderMixin +from ..modular_pipeline import ModularPipeline + + +class QwenImagePachifier(ConfigMixin): + """ + A class to pack and unpack latents for QwenImage. + """ + + config_name = "config.json" + + @register_to_config + def __init__(self, patch_size: int = 2): + super().__init__() + + def pack_latents(self, latents): + if latents.ndim != 4 and latents.ndim != 5: + raise ValueError(f"Latents must have 4 or 5 dimensions, but got {latents.ndim}") + + if latents.ndim == 4: + latents = latents.unsqueeze(2) + + batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width = latents.shape + patch_size = self.config.patch_size + + if latent_height % patch_size != 0 or latent_width % patch_size != 0: + raise ValueError( + f"Latent height and width must be divisible by {patch_size}, but got {latent_height} and {latent_width}" + ) + + latents = latents.view( + batch_size, + num_channels_latents, + latent_height // patch_size, + patch_size, + latent_width // patch_size, + patch_size, + ) + latents = latents.permute( + 0, 2, 4, 1, 3, 5 + ) # Batch_size, num_patches_height, num_patches_width, num_channels_latents, patch_size, patch_size + latents = latents.reshape( + batch_size, + (latent_height // patch_size) * (latent_width // patch_size), + num_channels_latents * patch_size * patch_size, + ) + + return latents + + def unpack_latents(self, latents, height, width, vae_scale_factor=8): + if latents.ndim != 3: + raise ValueError(f"Latents must have 3 dimensions, but got {latents.ndim}") + + batch_size, num_patches, channels = latents.shape + patch_size = self.config.patch_size + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = patch_size * (int(height) // (vae_scale_factor * patch_size)) + width = patch_size * (int(width) // (vae_scale_factor * patch_size)) + + latents = latents.view( + batch_size, + height // patch_size, + width // patch_size, + channels // (patch_size * patch_size), + patch_size, + patch_size, + ) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (patch_size * patch_size), 1, height, width) + + return latents + + +class QwenImageLayeredPachifier(ConfigMixin): + """ + A class to pack and unpack latents for QwenImage Layered. + + Unlike QwenImagePachifier, this handles 5D latents with shape (B, layers+1, C, H, W). + """ + + config_name = "config.json" + + @register_to_config + def __init__(self, patch_size: int = 2): + super().__init__() + + def pack_latents(self, latents): + """ + Pack latents from (B, layers, C, H, W) to (B, layers * H/2 * W/2, C*4). + """ + + if latents.ndim != 5: + raise ValueError(f"Latents must have 5 dimensions (B, layers, C, H, W), but got {latents.ndim}") + + batch_size, layers, num_channels_latents, latent_height, latent_width = latents.shape + patch_size = self.config.patch_size + + if latent_height % patch_size != 0 or latent_width % patch_size != 0: + raise ValueError( + f"Latent height and width must be divisible by {patch_size}, but got {latent_height} and {latent_width}" + ) + + latents = latents.view( + batch_size, + layers, + num_channels_latents, + latent_height // patch_size, + patch_size, + latent_width // patch_size, + patch_size, + ) + latents = latents.permute(0, 1, 3, 5, 2, 4, 6) + latents = latents.reshape( + batch_size, + layers * (latent_height // patch_size) * (latent_width // patch_size), + num_channels_latents * patch_size * patch_size, + ) + return latents + + def unpack_latents(self, latents, height, width, layers, vae_scale_factor=8): + """ + Unpack latents from (B, seq, C*4) to (B, C, layers+1, H, W). + """ + + if latents.ndim != 3: + raise ValueError(f"Latents must have 3 dimensions, but got {latents.ndim}") + + batch_size, _, channels = latents.shape + patch_size = self.config.patch_size + + height = patch_size * (int(height) // (vae_scale_factor * patch_size)) + width = patch_size * (int(width) // (vae_scale_factor * patch_size)) + + latents = latents.view( + batch_size, + layers + 1, + height // patch_size, + width // patch_size, + channels // (patch_size * patch_size), + patch_size, + patch_size, + ) + latents = latents.permute(0, 1, 4, 2, 5, 3, 6) + latents = latents.reshape( + batch_size, + layers + 1, + channels // (patch_size * patch_size), + height, + width, + ) + latents = latents.permute(0, 2, 1, 3, 4) # (b, c, f, h, w) + + return latents + + +class QwenImageModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin): + """ + A ModularPipeline for QwenImage. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "QwenImageAutoBlocks" + + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + return 128 + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** len(self.vae.temperal_downsample) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 16 + if hasattr(self, "transformer") and self.transformer is not None: + num_channels_latents = self.transformer.config.in_channels // 4 + return num_channels_latents + + @property + def is_guidance_distilled(self): + is_guidance_distilled = False + if hasattr(self, "transformer") and self.transformer is not None: + is_guidance_distilled = self.transformer.config.guidance_embeds + return is_guidance_distilled + + @property + def requires_unconditional_embeds(self): + requires_unconditional_embeds = False + + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 + + return requires_unconditional_embeds + + +class QwenImageEditModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin): + """ + A ModularPipeline for QwenImage-Edit. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "QwenImageEditAutoBlocks" + + # YiYi TODO: qwen edit should not provide default height/width, should be derived from the resized input image (after adjustment) produced by the resize step. + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + return 128 + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** len(self.vae.temperal_downsample) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 16 + if hasattr(self, "transformer") and self.transformer is not None: + num_channels_latents = self.transformer.config.in_channels // 4 + return num_channels_latents + + @property + def is_guidance_distilled(self): + is_guidance_distilled = False + if hasattr(self, "transformer") and self.transformer is not None: + is_guidance_distilled = self.transformer.config.guidance_embeds + return is_guidance_distilled + + @property + def requires_unconditional_embeds(self): + requires_unconditional_embeds = False + + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 + + return requires_unconditional_embeds + + +class QwenImageEditPlusModularPipeline(QwenImageEditModularPipeline): + """ + A ModularPipeline for QwenImage-Edit Plus. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "QwenImageEditPlusAutoBlocks" + + +class QwenImageLayeredModularPipeline(QwenImageModularPipeline): + """ + A ModularPipeline for QwenImage-Layered. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "QwenImageLayeredAutoBlocks" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/prompt_templates.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/prompt_templates.py new file mode 100644 index 0000000000000000000000000000000000000000..8e7beb555760a5f987342e406f05e0697a9c7bc3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/qwenimage/prompt_templates.py @@ -0,0 +1,121 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Prompt templates for QwenImage pipelines. + +This module centralizes all prompt templates used across different QwenImage pipeline variants: +- QwenImage (base): Text-only encoding for text-to-image generation +- QwenImage Edit: VL encoding with single image for image editing +- QwenImage Edit Plus: VL encoding with multiple images for multi-reference editing +- QwenImage Layered: Auto-captioning for image decomposition +""" + +# ============================================ +# QwenImage Base (text-only encoding) +# ============================================ +# Used for text-to-image generation where only text prompt is encoded + +QWENIMAGE_PROMPT_TEMPLATE = ( + "<|im_start|>system\n" + "Describe the image by detailing the color, shape, size, texture, quantity, text, " + "spatial relationships of the objects and background:<|im_end|>\n" + "<|im_start|>user\n{}<|im_end|>\n" + "<|im_start|>assistant\n" +) +QWENIMAGE_PROMPT_TEMPLATE_START_IDX = 34 + + +# ============================================ +# QwenImage Edit (VL encoding with single image) +# ============================================ +# Used for single-image editing where both image and text are encoded together + +QWENIMAGE_EDIT_PROMPT_TEMPLATE = ( + "<|im_start|>system\n" + "Describe the key features of the input image (color, shape, size, texture, objects, background), " + "then explain how the user's text instruction should alter or modify the image. " + "Generate a new image that meets the user's requirements while maintaining consistency " + "with the original input where appropriate.<|im_end|>\n" + "<|im_start|>user\n" + "<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n" + "<|im_start|>assistant\n" +) +QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX = 64 + + +# ============================================ +# QwenImage Edit Plus (VL encoding with multiple images) +# ============================================ +# Used for multi-reference editing where multiple images and text are encoded together +# The img_template is used to format each image in the prompt + +QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE = ( + "<|im_start|>system\n" + "Describe the key features of the input image (color, shape, size, texture, objects, background), " + "then explain how the user's text instruction should alter or modify the image. " + "Generate a new image that meets the user's requirements while maintaining consistency " + "with the original input where appropriate.<|im_end|>\n" + "<|im_start|>user\n{}<|im_end|>\n" + "<|im_start|>assistant\n" +) +QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>" +QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX = 64 + + +# ============================================ +# QwenImage Layered (auto-captioning) +# ============================================ +# Used for image decomposition where the VL model generates a caption from the input image +# if no prompt is provided. These prompts instruct the model to describe the image in detail. + +QWENIMAGE_LAYERED_CAPTION_PROMPT_EN = ( + "<|im_start|>system\n" + "You are a helpful assistant.<|im_end|>\n" + "<|im_start|>user\n" + "# Image Annotator\n" + "You are a professional image annotator. Please write an image caption based on the input image:\n" + "1. Write the caption using natural, descriptive language without structured formats or rich text.\n" + "2. Enrich caption details by including:\n" + " - Object attributes, such as quantity, color, shape, size, material, state, position, actions, and so on\n" + " - Vision Relations between objects, such as spatial relations, functional relations, possessive relations, " + "attachment relations, action relations, comparative relations, causal relations, and so on\n" + " - Environmental details, such as weather, lighting, colors, textures, atmosphere, and so on\n" + " - Identify the text clearly visible in the image, without translation or explanation, " + "and highlight it in the caption with quotation marks\n" + "3. Maintain authenticity and accuracy:\n" + " - Avoid generalizations\n" + " - Describe all visible information in the image, while do not add information not explicitly shown in the image\n" + "<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n" + "<|im_start|>assistant\n" +) + +QWENIMAGE_LAYERED_CAPTION_PROMPT_CN = ( + "<|im_start|>system\n" + "You are a helpful assistant.<|im_end|>\n" + "<|im_start|>user\n" + "# 图像标注器\n" + "你是一个专业的图像标注器。请基于输入图像,撰写图注:\n" + "1. 使用自然、描述性的语言撰写图注,不要使用结构化形式或富文本形式。\n" + "2. 通过加入以下内容,丰富图注细节:\n" + " - 对象的属性:如数量、颜色、形状、大小、位置、材质、状态、动作等\n" + " - 对象间的视觉关系:如空间关系、功能关系、动作关系、从属关系、比较关系、因果关系等\n" + " - 环境细节:例如天气、光照、颜色、纹理、气氛等\n" + " - 文字内容:识别图像中清晰可见的文字,不做翻译和解释,用引号在图注中强调\n" + "3. 保持真实性与准确性:\n" + " - 不要使用笼统的描述\n" + " - 描述图像中所有可见的信息,但不要加入没有在图像中出现的内容\n" + "<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n" + "<|im_start|>assistant\n" +) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..44f1c555cef3586ac9c05d53d45773040e3e77b1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_blocks_stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks"] + _import_structure["modular_pipeline"] = ["StableDiffusionXLModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_blocks_stable_diffusion_xl import StableDiffusionXLAutoBlocks + from .modular_pipeline import StableDiffusionXLModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3417d42884a2e4f835cf0eef0f1846b76a420a7f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/__pycache__/before_denoise.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/__pycache__/before_denoise.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37cbc4a7059b61d4064a77da3d534481dcc9f0a5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/__pycache__/before_denoise.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/__pycache__/decoders.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/__pycache__/decoders.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b7fd85b340410442b2b76ad54c84ac559aef548 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/__pycache__/decoders.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/__pycache__/denoise.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/__pycache__/denoise.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b68ffe6f71ed5cfdaa12ea73ee7edc93b5345797 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/__pycache__/denoise.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/__pycache__/encoders.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/__pycache__/encoders.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a3508e1f596665fed990d380dbd219569d3001a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/__pycache__/encoders.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/__pycache__/modular_blocks_stable_diffusion_xl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/__pycache__/modular_blocks_stable_diffusion_xl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..255811b1f9316079b16c96efe06e95976e35229b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/__pycache__/modular_blocks_stable_diffusion_xl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/__pycache__/modular_pipeline.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/__pycache__/modular_pipeline.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78af48cdee6faf592fba63adc85c3c5c9d2d4c24 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/__pycache__/modular_pipeline.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..4a393e7ce296f13d0cb470535ff9a354dffcbdd6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -0,0 +1,1874 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any + +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, UNet2DConditionModel +from ...models.controlnets.multicontrolnet import MultiControlNetModel +from ...schedulers import EulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor, unwrap_module +from ..modular_pipeline import ( + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from .modular_pipeline import StableDiffusionXLModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that +# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by +# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the +# configuration of guider is. + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def prepare_latents_img2img( + vae, scheduler, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True +): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError(f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}") + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + latents_mean = latents_std = None + if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None: + latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None: + latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1) + # make sure the VAE is in float32 mode, as it overflows in float16 + if vae.config.force_upcast: + image = image.float() + vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(vae.encode(image), generator=generator) + + if vae.config.force_upcast: + vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * vae.config.scaling_factor / latents_std + else: + init_latents = vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + +class StableDiffusionXLInputStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Input processing step that:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" + "All input tensors are expected to have either batch_size=1 or match the batch_size\n" + "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" + "have a final batch_size of batch_size * num_images_per_prompt." + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + description="Pre-generated negative text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "negative_pooled_prompt_embeds", + description="Pre-generated negative pooled text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "ip_adapter_embeds", + type_hint=list[torch.Tensor], + description="Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step.", + ), + InputParam( + "negative_ip_adapter_embeds", + type_hint=list[torch.Tensor], + description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `prompt_embeds`)", + ), + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields + description="text embeddings used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields + description="negative text embeddings used to guide the image generation", + ), + OutputParam( + "pooled_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields + description="pooled text embeddings used to guide the image generation", + ), + OutputParam( + "negative_pooled_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields + description="negative pooled text embeddings used to guide the image generation", + ), + OutputParam( + "ip_adapter_embeds", + type_hint=list[torch.Tensor], + kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields + description="image embeddings for IP-Adapter", + ), + OutputParam( + "negative_ip_adapter_embeds", + type_hint=list[torch.Tensor], + kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields + description="negative image embeddings for IP-Adapter", + ), + ] + + def check_inputs(self, components, block_state): + if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: + if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`" + f" {block_state.negative_prompt_embeds.shape}." + ) + + if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if block_state.negative_prompt_embeds is not None and block_state.negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if block_state.ip_adapter_embeds is not None and not isinstance(block_state.ip_adapter_embeds, list): + raise ValueError("`ip_adapter_embeds` must be a list") + + if block_state.negative_ip_adapter_embeds is not None and not isinstance( + block_state.negative_ip_adapter_embeds, list + ): + raise ValueError("`negative_ip_adapter_embeds` must be a list") + + if block_state.ip_adapter_embeds is not None and block_state.negative_ip_adapter_embeds is not None: + for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): + if ip_adapter_embed.shape != block_state.negative_ip_adapter_embeds[i].shape: + raise ValueError( + "`ip_adapter_embeds` and `negative_ip_adapter_embeds` must have the same shape when passed directly, but" + f" got: `ip_adapter_embeds` {ip_adapter_embed.shape} != `negative_ip_adapter_embeds`" + f" {block_state.negative_ip_adapter_embeds[i].shape}." + ) + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, -1 + ) + + if block_state.negative_pooled_prompt_embeds is not None: + block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, -1 + ) + + if block_state.ip_adapter_embeds is not None: + for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds): + block_state.ip_adapter_embeds[i] = torch.cat( + [ip_adapter_embed] * block_state.num_images_per_prompt, dim=0 + ) + + if block_state.negative_ip_adapter_embeds is not None: + for i, negative_ip_adapter_embed in enumerate(block_state.negative_ip_adapter_embeds): + block_state.negative_ip_adapter_embeds[i] = torch.cat( + [negative_ip_adapter_embed] * block_state.num_images_per_prompt, dim=0 + ) + + self.set_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLImg2ImgSetTimestepsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step that sets the timesteps for the scheduler and determines the initial noise level (latent_timestep) for image-to-image/inpainting generation.\n" + + "The latent_timestep is calculated from the `strength` parameter - higher strength means starting from a noisier version of the input image." + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("denoising_end"), + InputParam("strength", default=0.3), + InputParam("denoising_start"), + # YiYi TODO: do we need num_images_per_prompt here? + InputParam("num_images_per_prompt", default=1), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam( + "num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time", + ), + OutputParam( + "latent_timestep", + type_hint=torch.Tensor, + description="The timestep that represents the initial noise level for image-to-image generation", + ), + ] + + @staticmethod + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps with self->components + def get_timesteps(components, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start * components.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + components.scheduler.config.num_train_timesteps + - (denoising_start * components.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (components.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if components.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(components.scheduler.timesteps) - num_inference_steps + timesteps = components.scheduler.timesteps[t_start:] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.device = components._execution_device + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + block_state.num_inference_steps, + block_state.device, + block_state.timesteps, + block_state.sigmas, + ) + + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + block_state.timesteps, block_state.num_inference_steps = self.get_timesteps( + components, + block_state.num_inference_steps, + block_state.strength, + block_state.device, + denoising_start=block_state.denoising_start + if denoising_value_valid(block_state.denoising_start) + else None, + ) + block_state.latent_timestep = block_state.timesteps[:1].repeat( + block_state.batch_size * block_state.num_images_per_prompt + ) + + if ( + block_state.denoising_end is not None + and isinstance(block_state.denoising_end, float) + and block_state.denoising_end > 0 + and block_state.denoising_end < 1 + ): + block_state.discrete_timestep_cutoff = int( + round( + components.scheduler.config.num_train_timesteps + - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) + ) + ) + block_state.num_inference_steps = len( + list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps)) + ) + block_state.timesteps = block_state.timesteps[: block_state.num_inference_steps] + + self.set_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLSetTimestepsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("denoising_end"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam( + "num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time", + ), + ] + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.device = components._execution_device + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + block_state.num_inference_steps, + block_state.device, + block_state.timesteps, + block_state.sigmas, + ) + + if ( + block_state.denoising_end is not None + and isinstance(block_state.denoising_end, float) + and block_state.denoising_end > 0 + and block_state.denoising_end < 1 + ): + block_state.discrete_timestep_cutoff = int( + round( + components.scheduler.config.num_train_timesteps + - (block_state.denoising_end * components.scheduler.config.num_train_timesteps) + ) + ) + block_state.num_inference_steps = len( + list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps)) + ) + block_state.timesteps = block_state.timesteps[: block_state.num_inference_steps] + + self.set_block_state(state, block_state) + return components, state + + +class StableDiffusionXLInpaintPrepareLatentsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "Step that prepares the latents for the inpainting process" + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + InputParam("denoising_start"), + InputParam( + "strength", + default=0.9999, + description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` " + "will be used as a starting point, adding more noise to it the larger the `strength`. The number of " + "denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will " + "be maximum and the denoising process will run for the full number of iterations specified in " + "`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of " + "`denoising_start` being declared as an integer, the value of `strength` will be ignored.", + ), + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", + ), + InputParam( + "latent_timestep", + required=True, + type_hint=torch.Tensor, + description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.", + ), + InputParam( + "image_latents", + required=True, + type_hint=torch.Tensor, + description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.", + ), + InputParam( + "mask", + required=True, + type_hint=torch.Tensor, + description="The mask for the inpainting generation. Can be generated in vae_encode step.", + ), + InputParam( + "masked_image_latents", + type_hint=torch.Tensor, + description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step.", + ), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ), + OutputParam( + "noise", + type_hint=torch.Tensor, + description="The noise added to the image latents, used for inpainting generation", + ), + ] + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self->components + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + @staticmethod + def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator): + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std + else: + image_latents = components.vae.config.scaling_factor * image_latents + + return image_latents + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument + def prepare_latents_inpaint( + self, + components, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + add_noise=True, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // components.vae_scale_factor, + int(width) // components.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if image.shape[1] == 4: + image_latents = image.to(device=device, dtype=dtype) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + elif latents is None and not is_strength_max: + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(components, image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None and add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else components.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * components.scheduler.init_noise_sigma if is_strength_max else latents + elif add_noise: + noise = latents.to(device) + latents = noise * components.scheduler.init_noise_sigma + else: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = image_latents.to(device) + + outputs = (latents, noise, image_latents) + + return outputs + + # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents + # do not accept do_classifier_free_guidance + def prepare_mask_latents( + self, components, mask, masked_image, batch_size, height, width, dtype, device, generator + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + + block_state.is_strength_max = block_state.strength == 1.0 + + # for non-inpainting specific unet, we do not need masked_image_latents + if hasattr(components, "unet") and components.unet is not None: + if components.unet.config.in_channels == 4: + block_state.masked_image_latents = None + + block_state.add_noise = True if block_state.denoising_start is None else False + + block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor + block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor + + block_state.latents, block_state.noise, block_state.image_latents = self.prepare_latents_inpaint( + components, + block_state.batch_size * block_state.num_images_per_prompt, + components.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + image=block_state.image_latents, + timestep=block_state.latent_timestep, + is_strength_max=block_state.is_strength_max, + add_noise=block_state.add_noise, + ) + + # 7. Prepare mask latent variables + block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( + components, + block_state.mask, + block_state.masked_image_latents, + block_state.batch_size * block_state.num_images_per_prompt, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + ) + + self.set_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLImg2ImgPrepareLatentsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "Step that prepares the latents for the image-to-image generation process" + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + InputParam("denoising_start"), + InputParam("generator"), + InputParam( + "latent_timestep", + required=True, + type_hint=torch.Tensor, + description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.", + ), + InputParam( + "image_latents", + required=True, + type_hint=torch.Tensor, + description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.", + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", + ), + InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ) + ] + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + block_state.add_noise = True if block_state.denoising_start is None else False + if block_state.latents is None: + block_state.latents = prepare_latents_img2img( + components.vae, + components.scheduler, + block_state.image_latents, + block_state.latent_timestep, + block_state.batch_size, + block_state.num_images_per_prompt, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.add_noise, + ) + + self.set_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLPrepareLatentsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("vae", AutoencoderKL), + ] + + @property + def description(self) -> str: + return "Prepare latents step that prepares the latents for the text-to-image generation process" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("height"), + InputParam("width"), + InputParam("latents"), + InputParam("num_images_per_prompt", default=1), + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", + ), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ) + ] + + @staticmethod + def check_inputs(components, block_state): + if ( + block_state.height is not None + and block_state.height % components.vae_scale_factor != 0 + or block_state.width is not None + and block_state.width % components.vae_scale_factor != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}." + ) + + @staticmethod + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with self->comp + def prepare_latents(comp, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // comp.vae_scale_factor, + int(width) // comp.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * comp.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if block_state.dtype is None: + block_state.dtype = components.vae.dtype + + block_state.device = components._execution_device + + self.check_inputs(components, block_state) + + block_state.height = block_state.height or components.default_sample_size * components.vae_scale_factor + block_state.width = block_state.width or components.default_sample_size * components.vae_scale_factor + block_state.num_channels_latents = components.num_channels_latents + block_state.latents = self.prepare_latents( + components, + block_state.batch_size * block_state.num_images_per_prompt, + block_state.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + ) + + self.set_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_configs(self) -> list[ConfigSpec]: + return [ + ConfigSpec("requires_aesthetics_score", False), + ] + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that prepares the additional conditioning for the image-to-image/inpainting generation process" + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("original_size"), + InputParam("target_size"), + InputParam("negative_original_size"), + InputParam("negative_target_size"), + InputParam("crops_coords_top_left", default=(0, 0)), + InputParam("negative_crops_coords_top_left", default=(0, 0)), + InputParam("num_images_per_prompt", default=1), + InputParam("aesthetic_score", default=6.0), + InputParam("negative_aesthetic_score", default=2.0), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step.", + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "add_time_ids", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="The time ids to condition the denoising process", + ), + OutputParam( + "negative_add_time_ids", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="The negative time ids to condition the denoising process", + ), + OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"), + ] + + @staticmethod + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self->components + def _get_add_time_ids( + components, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if components.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == components.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == components.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + block_state.vae_scale_factor = components.vae_scale_factor + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * block_state.vae_scale_factor + block_state.width = block_state.width * block_state.vae_scale_factor + + block_state.original_size = block_state.original_size or (block_state.height, block_state.width) + block_state.target_size = block_state.target_size or (block_state.height, block_state.width) + + block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) + + if block_state.negative_original_size is None: + block_state.negative_original_size = block_state.original_size + if block_state.negative_target_size is None: + block_state.negative_target_size = block_state.target_size + + block_state.add_time_ids, block_state.negative_add_time_ids = self._get_add_time_ids( + components, + block_state.original_size, + block_state.crops_coords_top_left, + block_state.target_size, + block_state.aesthetic_score, + block_state.negative_aesthetic_score, + block_state.negative_original_size, + block_state.negative_crops_coords_top_left, + block_state.negative_target_size, + dtype=block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, + ) + block_state.add_time_ids = block_state.add_time_ids.repeat( + block_state.batch_size * block_state.num_images_per_prompt, 1 + ).to(device=block_state.device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat( + block_state.batch_size * block_state.num_images_per_prompt, 1 + ).to(device=block_state.device) + + # Optionally get Guidance Scale Embedding for LCM + block_state.timestep_cond = None + if ( + hasattr(components, "unet") + and components.unet is not None + and components.unet.config.time_cond_proj_dim is not None + ): + # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! + block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat( + block_state.batch_size * block_state.num_images_per_prompt + ) + block_state.timestep_cond = self.get_guidance_scale_embedding( + block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim + ).to(device=block_state.device, dtype=block_state.latents.dtype) + + self.set_block_state(state, block_state) + return components, state + + +class StableDiffusionXLPrepareAdditionalConditioningStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return "Step that prepares the additional conditioning for the text-to-image generation process" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("original_size"), + InputParam("target_size"), + InputParam("negative_original_size"), + InputParam("negative_target_size"), + InputParam("crops_coords_top_left", default=(0, 0)), + InputParam("negative_crops_coords_top_left", default=(0, 0)), + InputParam("num_images_per_prompt", default=1), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step.", + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "add_time_ids", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="The time ids to condition the denoising process", + ), + OutputParam( + "negative_add_time_ids", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="The negative time ids to condition the denoising process", + ), + OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"), + ] + + @staticmethod + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self->components + def _get_add_time_ids( + components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + block_state.original_size = block_state.original_size or (block_state.height, block_state.width) + block_state.target_size = block_state.target_size or (block_state.height, block_state.width) + + block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1]) + + block_state.add_time_ids = self._get_add_time_ids( + components, + block_state.original_size, + block_state.crops_coords_top_left, + block_state.target_size, + block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, + ) + if block_state.negative_original_size is not None and block_state.negative_target_size is not None: + block_state.negative_add_time_ids = self._get_add_time_ids( + components, + block_state.negative_original_size, + block_state.negative_crops_coords_top_left, + block_state.negative_target_size, + block_state.pooled_prompt_embeds.dtype, + text_encoder_projection_dim=block_state.text_encoder_projection_dim, + ) + else: + block_state.negative_add_time_ids = block_state.add_time_ids + + block_state.add_time_ids = block_state.add_time_ids.repeat( + block_state.batch_size * block_state.num_images_per_prompt, 1 + ).to(device=block_state.device) + block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat( + block_state.batch_size * block_state.num_images_per_prompt, 1 + ).to(device=block_state.device) + + # Optionally get Guidance Scale Embedding for LCM + block_state.timestep_cond = None + if ( + hasattr(components, "unet") + and components.unet is not None + and components.unet.config.time_cond_proj_dim is not None + ): + # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this! + block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat( + block_state.batch_size * block_state.num_images_per_prompt + ) + block_state.timestep_cond = self.get_guidance_scale_embedding( + block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim + ).to(device=block_state.device, dtype=block_state.latents.dtype) + + self.set_block_state(state, block_state) + return components, state + + +class StableDiffusionXLControlNetInputStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("controlnet", ControlNetModel), + ComponentSpec( + "control_image_processor", + VaeImageProcessor, + config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "step that prepare inputs for controlnet" + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("control_image", required=True), + InputParam("control_guidance_start", default=0.0), + InputParam("control_guidance_end", default=1.0), + InputParam("controlnet_conditioning_scale", default=1.0), + InputParam("guess_mode", default=False), + InputParam("num_images_per_prompt", default=1), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + "crops_coords", + type_hint=tuple[int] | None, + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image"), + OutputParam( + "control_guidance_start", type_hint=list[float], description="The controlnet guidance start values" + ), + OutputParam( + "control_guidance_end", type_hint=list[float], description="The controlnet guidance end values" + ), + OutputParam( + "conditioning_scale", type_hint=list[float], description="The controlnet conditioning scale values" + ), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("controlnet_keep", type_hint=list[float], description="The controlnet keep values"), + ] + + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # 1. return image without apply any guidance + # 2. add crops_coords and resize_mode to preprocess() + @staticmethod + def prepare_control_image( + components, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords=None, + ): + if crops_coords is not None: + image = components.control_image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill" + ).to(dtype=torch.float32) + else: + image = components.control_image_processor.preprocess(image, height=height, width=width).to( + dtype=torch.float32 + ) + + image_batch_size = image.shape[0] + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + image = image.to(device=device, dtype=dtype) + return image + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # (1) prepare controlnet inputs + block_state.device = components._execution_device + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + controlnet = unwrap_module(components.controlnet) + + # (1.1) + # control_guidance_start/control_guidance_end (align format) + if not isinstance(block_state.control_guidance_start, list) and isinstance( + block_state.control_guidance_end, list + ): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [ + block_state.control_guidance_start + ] + elif not isinstance(block_state.control_guidance_end, list) and isinstance( + block_state.control_guidance_start, list + ): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [ + block_state.control_guidance_end + ] + elif not isinstance(block_state.control_guidance_start, list) and not isinstance( + block_state.control_guidance_end, list + ): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + block_state.control_guidance_start, block_state.control_guidance_end = ( + mult * [block_state.control_guidance_start], + mult * [block_state.control_guidance_end], + ) + + # (1.2) + # controlnet_conditioning_scale (align format) + if isinstance(controlnet, MultiControlNetModel) and isinstance( + block_state.controlnet_conditioning_scale, float + ): + block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len( + controlnet.nets + ) + + # (1.3) + # global_pool_conditions + block_state.global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + # (1.4) + # guess_mode + block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions + + # (1.5) + # control_image + if isinstance(controlnet, ControlNetModel): + block_state.control_image = self.prepare_control_image( + components, + image=block_state.control_image, + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=block_state.device, + dtype=controlnet.dtype, + crops_coords=block_state.crops_coords, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in block_state.control_image: + control_image = self.prepare_control_image( + components, + image=control_image_, + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=block_state.device, + dtype=controlnet.dtype, + crops_coords=block_state.crops_coords, + ) + + control_images.append(control_image) + + block_state.control_image = control_images + else: + assert False + + # (1.6) + # controlnet_keep + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): + keeps = [ + 1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e) + for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end) + ] + block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + block_state.controlnet_cond = block_state.control_image + block_state.conditioning_scale = block_state.controlnet_conditioning_scale + + self.set_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLControlNetUnionInputStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("controlnet", ControlNetUnionModel), + ComponentSpec( + "control_image_processor", + VaeImageProcessor, + config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "step that prepares inputs for the ControlNetUnion model" + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("control_image", required=True), + InputParam("control_mode", required=True), + InputParam("control_guidance_start", default=0.0), + InputParam("control_guidance_end", default=1.0), + InputParam("controlnet_conditioning_scale", default=1.0), + InputParam("guess_mode", default=False), + InputParam("num_images_per_prompt", default=1), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Used to determine the shape of the control images. Can be generated in prepare_latent step.", + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of model tensor inputs. Can be generated in input step.", + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Needed to determine `controlnet_keep`. Can be generated in set_timesteps step.", + ), + InputParam( + "crops_coords", + type_hint=tuple[int] | None, + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("controlnet_cond", type_hint=list[torch.Tensor], description="The processed control images"), + OutputParam( + "control_type_idx", + type_hint=list[int], + description="The control mode indices", + kwargs_type="controlnet_kwargs", + ), + OutputParam( + "control_type", + type_hint=torch.Tensor, + description="The control type tensor that specifies which control type is active", + kwargs_type="controlnet_kwargs", + ), + OutputParam("control_guidance_start", type_hint=float, description="The controlnet guidance start value"), + OutputParam("control_guidance_end", type_hint=float, description="The controlnet guidance end value"), + OutputParam( + "conditioning_scale", type_hint=list[float], description="The controlnet conditioning scale values" + ), + OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"), + OutputParam("controlnet_keep", type_hint=list[float], description="The controlnet keep values"), + ] + + # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + # 1. return image without apply any guidance + # 2. add crops_coords and resize_mode to preprocess() + @staticmethod + def prepare_control_image( + components, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords=None, + ): + if crops_coords is not None: + image = components.control_image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill" + ).to(dtype=torch.float32) + else: + image = components.control_image_processor.preprocess(image, height=height, width=width).to( + dtype=torch.float32 + ) + + image_batch_size = image.shape[0] + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + image = image.to(device=device, dtype=dtype) + return image + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + controlnet = unwrap_module(components.controlnet) + + device = components._execution_device + dtype = block_state.dtype or components.controlnet.dtype + + block_state.height, block_state.width = block_state.latents.shape[-2:] + block_state.height = block_state.height * components.vae_scale_factor + block_state.width = block_state.width * components.vae_scale_factor + + # control_guidance_start/control_guidance_end (align format) + if not isinstance(block_state.control_guidance_start, list) and isinstance( + block_state.control_guidance_end, list + ): + block_state.control_guidance_start = len(block_state.control_guidance_end) * [ + block_state.control_guidance_start + ] + elif not isinstance(block_state.control_guidance_end, list) and isinstance( + block_state.control_guidance_start, list + ): + block_state.control_guidance_end = len(block_state.control_guidance_start) * [ + block_state.control_guidance_end + ] + + # guess_mode + block_state.global_pool_conditions = controlnet.config.global_pool_conditions + block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions + + # control_image + if not isinstance(block_state.control_image, list): + block_state.control_image = [block_state.control_image] + # control_mode + if not isinstance(block_state.control_mode, list): + block_state.control_mode = [block_state.control_mode] + + if len(block_state.control_image) != len(block_state.control_mode): + raise ValueError("Expected len(control_image) == len(control_type)") + + # control_type + block_state.num_control_type = controlnet.config.num_control_type + block_state.control_type = [0 for _ in range(block_state.num_control_type)] + for control_idx in block_state.control_mode: + block_state.control_type[control_idx] = 1 + block_state.control_type = torch.Tensor(block_state.control_type) + + block_state.control_type = block_state.control_type.reshape(1, -1).to(device, dtype=block_state.dtype) + repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0] + block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0) + + # prepare control_image + for idx, _ in enumerate(block_state.control_image): + block_state.control_image[idx] = self.prepare_control_image( + components, + image=block_state.control_image[idx], + width=block_state.width, + height=block_state.height, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_images_per_prompt=block_state.num_images_per_prompt, + device=device, + dtype=dtype, + crops_coords=block_state.crops_coords, + ) + block_state.height, block_state.width = block_state.control_image[idx].shape[-2:] + + # controlnet_keep + block_state.controlnet_keep = [] + for i in range(len(block_state.timesteps)): + block_state.controlnet_keep.append( + 1.0 + - float( + i / len(block_state.timesteps) < block_state.control_guidance_start + or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end + ) + ) + block_state.control_type_idx = block_state.control_mode + block_state.controlnet_cond = block_state.control_image + block_state.conditioning_scale = block_state.controlnet_conditioning_scale + + self.set_block_state(state, block_state) + + return components, state diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..7e505559f68569f9c36c4971f76c9513dad016ef --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py @@ -0,0 +1,198 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import numpy as np +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL +from ...utils import deprecate, logging +from ..modular_pipeline import ( + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableDiffusionXLDecodeStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images" + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("output_type", default="pil"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents from the denoising step", + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "images", + type_hint=list[PIL.Image.Image] | list[torch.Tensor] | list[np.array], + description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array", + ) + ] + + @staticmethod + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self->components + def upcast_vae(components): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + components.vae.to(dtype=torch.float32) + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if not block_state.output_type == "latent": + latents = block_state.latents + # make sure the VAE is in float32 mode, as it overflows in float16 + block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast + + if block_state.needs_upcasting: + self.upcast_vae(components) + latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != components.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + components.vae = components.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + block_state.has_latents_mean = ( + hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None + ) + block_state.has_latents_std = ( + hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None + ) + if block_state.has_latents_mean and block_state.has_latents_std: + block_state.latents_mean = ( + torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + block_state.latents_std = ( + torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = ( + latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean + ) + else: + latents = latents / components.vae.config.scaling_factor + + block_state.images = components.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if block_state.needs_upcasting: + components.vae.to(dtype=torch.float16) + else: + block_state.images = block_state.latents + + # apply watermark if available + if hasattr(components, "watermark") and components.watermark is not None: + block_state.images = components.watermark.apply_watermark(block_state.images) + + block_state.images = components.image_processor.postprocess( + block_state.images, output_type=block_state.output_type + ) + + self.set_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLInpaintOverlayMaskStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "A post-processing step that overlays the mask on the image (inpainting task only).\n" + + "only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask" + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("image"), + InputParam("mask_image"), + InputParam("padding_mask_crop"), + InputParam( + "images", + type_hint=list[PIL.Image.Image] | list[torch.Tensor] | list[np.array], + description="The generated images from the decode step", + ), + InputParam( + "crops_coords", + type_hint=tuple[int, int], + description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.", + ), + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if block_state.padding_mask_crop is not None and block_state.crops_coords is not None: + block_state.images = [ + components.image_processor.apply_overlay( + block_state.mask_image, block_state.image, i, block_state.crops_coords + ) + for i in block_state.images + ] + + self.set_block_state(state, block_state) + + return components, state diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..0190bc3ea62fa4842ffd0c944dcf61dc5ef64d9d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -0,0 +1,798 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import ControlNetModel, UNet2DConditionModel +from ...schedulers import EulerDiscreteScheduler +from ...utils import logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import StableDiffusionXLModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# YiYi experimenting composible denoise loop +# loop step (1): prepare latent input for denoiser +class StableDiffusionXLLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepare the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + + return components, block_state + + +# loop step (1): prepare latent input for denoiser (with inpainting) +class StableDiffusionXLInpaintLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepare the latent input for the denoiser (for inpainting workflow only). " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object" + ) + + @property + def inputs(self) -> list[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "mask", + type_hint=torch.Tensor | None, + description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.", + ), + InputParam( + "masked_image_latents", + type_hint=torch.Tensor | None, + description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.", + ), + ] + + @staticmethod + def check_inputs(components, block_state): + num_channels_unet = components.num_channels_unet + if num_channels_unet == 9: + # default case for stable-diffusion-v1-5/stable-diffusion-inpainting + if block_state.mask is None or block_state.masked_image_latents is None: + raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet") + num_channels_latents = block_state.latents.shape[1] + num_channels_mask = block_state.mask.shape[1] + num_channels_masked_image = block_state.masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: + raise ValueError( + f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects" + f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" + " `components.unet` or your `mask_image` or `image` input." + ) + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): + self.check_inputs(components, block_state) + + block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) + if components.num_channels_unet == 9: + block_state.scaled_latents = torch.cat( + [block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1 + ) + + return components, block_state + + +# loop step (2): denoise the latents with guidance +class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config", + ), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents with guidance. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("cross_attention_kwargs"), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + "timestep_cond", + type_hint=torch.Tensor | None, + description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step.", + ), + InputParam( + kwargs_type="denoiser_input_fields", + description=( + "All conditional model inputs that need to be prepared with guider. " + "It should contain prompt_embeds/negative_prompt_embeds, " + "add_time_ids/negative_add_time_ids, " + "pooled_prompt_embeds/negative_pooled_prompt_embeds, " + "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." + "please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ), + ), + ] + + @torch.no_grad() + def __call__( + self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int + ) -> PipelineState: + # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) + # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) + guider_inputs = { + "prompt_embeds": ( + getattr(block_state, "prompt_embeds", None), + getattr(block_state, "negative_prompt_embeds", None), + ), + "time_ids": ( + getattr(block_state, "add_time_ids", None), + getattr(block_state, "negative_add_time_ids", None), + ), + "text_embeds": ( + getattr(block_state, "pooled_prompt_embeds", None), + getattr(block_state, "negative_pooled_prompt_embeds", None), + ), + "image_embeds": ( + getattr(block_state, "ip_adapter_embeds", None), + getattr(block_state, "negative_ip_adapter_embeds", None), + ), + } + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). + guider_state = components.guider.prepare_inputs(guider_inputs) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.unet) + cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()} + prompt_embeds = cond_kwargs.pop("prompt_embeds") + + # Predict the noise residual + # store the noise_pred in guider_state_batch so that we can apply guidance across all batches + guider_state_batch.noise_pred = components.unet( + block_state.scaled_latents, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=block_state.timestep_cond, + cross_attention_kwargs=block_state.cross_attention_kwargs, + added_cond_kwargs=cond_kwargs, + return_dict=False, + )[0] + components.guider.cleanup_models(components.unet) + + # Perform guidance + block_state.noise_pred = components.guider(guider_state)[0] + + return components, block_state + + +# loop step (2): denoise the latents with guidance (with controlnet) +class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config", + ), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec("controlnet", ControlNetModel), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that denoise the latents with guidance (with controlnet). " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("cross_attention_kwargs"), + InputParam( + "controlnet_cond", + required=True, + type_hint=torch.Tensor, + description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", + ), + InputParam( + "conditioning_scale", + type_hint=float, + description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", + ), + InputParam( + "guess_mode", + required=True, + type_hint=bool, + description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", + ), + InputParam( + "controlnet_keep", + required=True, + type_hint=list[float], + description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", + ), + InputParam( + "timestep_cond", + type_hint=torch.Tensor | None, + description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + kwargs_type="denoiser_input_fields", + description=( + "All conditional model inputs that need to be prepared with guider. " + "It should contain prompt_embeds/negative_prompt_embeds, " + "add_time_ids/negative_add_time_ids, " + "pooled_prompt_embeds/negative_pooled_prompt_embeds, " + "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)." + "please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ), + ), + InputParam( + kwargs_type="controlnet_kwargs", + description=( + "additional kwargs for controlnet (e.g. control_type_idx and control_type from the controlnet union input step )" + "please add `kwargs_type=controlnet_kwargs` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ), + ), + ] + + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): + extra_controlnet_kwargs = self.prepare_extra_kwargs( + components.controlnet.forward, **block_state.controlnet_kwargs + ) + + # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) + # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) + guider_inputs = { + "prompt_embeds": ( + getattr(block_state, "prompt_embeds", None), + getattr(block_state, "negative_prompt_embeds", None), + ), + "time_ids": ( + getattr(block_state, "add_time_ids", None), + getattr(block_state, "negative_add_time_ids", None), + ), + "text_embeds": ( + getattr(block_state, "pooled_prompt_embeds", None), + getattr(block_state, "negative_pooled_prompt_embeds", None), + ), + "image_embeds": ( + getattr(block_state, "ip_adapter_embeds", None), + getattr(block_state, "negative_ip_adapter_embeds", None), + ), + } + + # cond_scale for the timestep (controlnet input) + if isinstance(block_state.controlnet_keep[i], list): + block_state.cond_scale = [ + c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i]) + ] + else: + controlnet_cond_scale = block_state.conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i] + + # default controlnet output/unet input for guess mode + conditional path + block_state.down_block_res_samples_zeros = None + block_state.mid_block_res_sample_zeros = None + + # guided denoiser step + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). + guider_state = components.guider.prepare_inputs(guider_inputs) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.unet) + + # Prepare additional conditionings + added_cond_kwargs = { + "text_embeds": guider_state_batch.text_embeds, + "time_ids": guider_state_batch.time_ids, + } + if hasattr(guider_state_batch, "image_embeds") and guider_state_batch.image_embeds is not None: + added_cond_kwargs["image_embeds"] = guider_state_batch.image_embeds + + # Prepare controlnet additional conditionings + controlnet_added_cond_kwargs = { + "text_embeds": guider_state_batch.text_embeds, + "time_ids": guider_state_batch.time_ids, + } + # run controlnet for the guidance batch + if block_state.guess_mode and not components.guider.is_conditional: + # guider always run uncond batch first, so these tensors should be set already + down_block_res_samples = block_state.down_block_res_samples_zeros + mid_block_res_sample = block_state.mid_block_res_sample_zeros + else: + down_block_res_samples, mid_block_res_sample = components.controlnet( + block_state.scaled_latents, + t, + encoder_hidden_states=guider_state_batch.prompt_embeds, + controlnet_cond=block_state.controlnet_cond, + conditioning_scale=block_state.cond_scale, + guess_mode=block_state.guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + **extra_controlnet_kwargs, + ) + + # assign it to block_state so it will be available for the uncond guidance batch + if block_state.down_block_res_samples_zeros is None: + block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in down_block_res_samples] + if block_state.mid_block_res_sample_zeros is None: + block_state.mid_block_res_sample_zeros = torch.zeros_like(mid_block_res_sample) + + # Predict the noise + # store the noise_pred in guider_state_batch so we can apply guidance across all batches + guider_state_batch.noise_pred = components.unet( + block_state.scaled_latents, + t, + encoder_hidden_states=guider_state_batch.prompt_embeds, + timestep_cond=block_state.timestep_cond, + cross_attention_kwargs=block_state.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + components.guider.cleanup_models(components.unet) + + # Perform guidance + block_state.noise_pred = components.guider(guider_state)[0] + + return components, block_state + + +# loop step (3): scheduler step to update latents +class StableDiffusionXLLoopAfterDenoiser(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that update the latents. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("eta", default=0.0), + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + # YiYi TODO: move this out of here + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + block_state.extra_step_kwargs = self.prepare_extra_kwargs( + components.scheduler.step, generator=block_state.generator, eta=block_state.eta + ) + + # Perform scheduler step using the predicted output + block_state.latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred, + t, + block_state.latents, + **block_state.extra_step_kwargs, + return_dict=False, + )[0] + + if block_state.latents.dtype != block_state.latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + block_state.latents = block_state.latents.to(block_state.latents_dtype) + + return components, block_state + + +# loop step (3): scheduler step to update latents (with inpainting) +class StableDiffusionXLInpaintLoopAfterDenoiser(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that update the latents (for inpainting workflow only). " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("eta", default=0.0), + InputParam("generator"), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + "mask", + type_hint=torch.Tensor | None, + description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.", + ), + InputParam( + "noise", + type_hint=torch.Tensor | None, + description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step.", + ), + InputParam( + "image_latents", + type_hint=torch.Tensor | None, + description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + @staticmethod + def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs): + accepted_kwargs = set(inspect.signature(func).parameters.keys()) + extra_kwargs = {} + for key, value in kwargs.items(): + if key in accepted_kwargs and key not in exclude_kwargs: + extra_kwargs[key] = value + + return extra_kwargs + + def check_inputs(self, components, block_state): + if components.num_channels_unet == 4: + if block_state.image_latents is None: + raise ValueError(f"image_latents is required for this step {self.__class__.__name__}") + if block_state.mask is None: + raise ValueError(f"mask is required for this step {self.__class__.__name__}") + if block_state.noise is None: + raise ValueError(f"noise is required for this step {self.__class__.__name__}") + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): + self.check_inputs(components, block_state) + + # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + block_state.extra_step_kwargs = self.prepare_extra_kwargs( + components.scheduler.step, generator=block_state.generator, eta=block_state.eta + ) + + # Perform scheduler step using the predicted output + block_state.latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred, + t, + block_state.latents, + **block_state.extra_step_kwargs, + return_dict=False, + )[0] + + if block_state.latents.dtype != block_state.latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + block_state.latents = block_state.latents.to(block_state.latents_dtype) + + # adjust latent for inpainting + if components.num_channels_unet == 4: + block_state.init_latents_proper = block_state.image_latents + if i < len(block_state.timesteps) - 1: + block_state.noise_timestep = block_state.timesteps[i + 1] + block_state.init_latents_proper = components.scheduler.add_noise( + block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep]) + ) + + block_state.latents = ( + 1 - block_state.mask + ) * block_state.init_latents_proper + block_state.mask * block_state.latents + + return components, block_state + + +# the loop wrapper that iterates over the timesteps +class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoise the latents over `timesteps`. " + "The specific steps with each iteration can be customized with `sub_blocks` attributes" + ) + + @property + def loop_expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config", + ), + ComponentSpec("scheduler", EulerDiscreteScheduler), + ComponentSpec("unet", UNet2DConditionModel), + ] + + @property + def loop_inputs(self) -> list[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False + if block_state.disable_guidance: + components.guider.disable() + else: + components.guider.enable() + + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + self.set_block_state(state, block_state) + + return components, state + + +# composing the denoising loops +class StableDiffusionXLDenoiseStep(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [ + StableDiffusionXLLoopBeforeDenoiser, + StableDiffusionXLLoopDenoiser, + StableDiffusionXLLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `StableDiffusionXLLoopBeforeDenoiser`\n" + " - `StableDiffusionXLLoopDenoiser`\n" + " - `StableDiffusionXLLoopAfterDenoiser`\n" + "This block supports both text2img and img2img tasks." + ) + + +# control_cond +class StableDiffusionXLControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [ + StableDiffusionXLLoopBeforeDenoiser, + StableDiffusionXLControlNetLoopDenoiser, + StableDiffusionXLLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents with controlnet. \n" + "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `StableDiffusionXLLoopBeforeDenoiser`\n" + " - `StableDiffusionXLControlNetLoopDenoiser`\n" + " - `StableDiffusionXLLoopAfterDenoiser`\n" + "This block supports using controlnet for both text2img and img2img tasks." + ) + + +# mask +class StableDiffusionXLInpaintDenoiseStep(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [ + StableDiffusionXLInpaintLoopBeforeDenoiser, + StableDiffusionXLLoopDenoiser, + StableDiffusionXLInpaintLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents(for inpainting task only). \n" + "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `StableDiffusionXLInpaintLoopBeforeDenoiser`\n" + " - `StableDiffusionXLLoopDenoiser`\n" + " - `StableDiffusionXLInpaintLoopAfterDenoiser`\n" + "This block onlysupports inpainting tasks." + ) + + +# control_cond + mask +class StableDiffusionXLInpaintControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper): + block_classes = [ + StableDiffusionXLInpaintLoopBeforeDenoiser, + StableDiffusionXLControlNetLoopDenoiser, + StableDiffusionXLInpaintLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents(for inpainting task only) with controlnet. \n" + "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `StableDiffusionXLInpaintLoopBeforeDenoiser`\n" + " - `StableDiffusionXLControlNetLoopDenoiser`\n" + " - `StableDiffusionXLInpaintLoopAfterDenoiser`\n" + "This block only supports using controlnet for inpainting tasks." + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..8387ae7bd6b6d06bee9e1b9b0bf63a6f28191781 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py @@ -0,0 +1,885 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...utils import ( + USE_PEFT_BACKEND, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from .modular_pipeline import StableDiffusionXLModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class StableDiffusionXLIPAdapterStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return ( + "IP Adapter step that prepares ip adapter image embeddings.\n" + "Note that this step only prepares the embeddings - in order for it to work correctly, " + "you need to load ip adapter weights into unet via ModularPipeline.load_ip_adapter() and pipeline.set_ip_adapter_scale().\n" + "See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" + " for more details" + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("image_encoder", CLIPVisionModelWithProjection), + ComponentSpec( + "feature_extractor", + CLIPImageProcessor, + config=FrozenDict({"size": 224, "crop_size": 224}), + default_creation_method="from_config", + ), + ComponentSpec("unet", UNet2DConditionModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "ip_adapter_image", + PipelineImageInput, + required=True, + description="The image(s) to be used as ip adapter", + ) + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"), + OutputParam( + "negative_ip_adapter_embeds", + type_hint=torch.Tensor, + description="Negative IP adapter image embeddings", + ), + ] + + @staticmethod + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self->components + def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(components.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = components.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = components.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = components.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, + components, + ip_adapter_image, + ip_adapter_image_embeds, + device, + num_images_per_prompt, + prepare_unconditional_embeds, + ): + image_embeds = [] + if prepare_unconditional_embeds: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + components, single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if prepare_unconditional_embeds: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if prepare_unconditional_embeds: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if prepare_unconditional_embeds: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 + block_state.device = components._execution_device + + block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds( + components, + ip_adapter_image=block_state.ip_adapter_image, + ip_adapter_image_embeds=None, + device=block_state.device, + num_images_per_prompt=1, + prepare_unconditional_embeds=block_state.prepare_unconditional_embeds, + ) + if block_state.prepare_unconditional_embeds: + block_state.negative_ip_adapter_embeds = [] + for i, image_embeds in enumerate(block_state.ip_adapter_embeds): + negative_image_embeds, image_embeds = image_embeds.chunk(2) + block_state.negative_ip_adapter_embeds.append(negative_image_embeds) + block_state.ip_adapter_embeds[i] = image_embeds + + self.set_block_state(state, block_state) + return components, state + + +class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return "Text Encoder step that generate text_embeddings to guide the image generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", CLIPTextModel), + ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), + ComponentSpec("tokenizer", CLIPTokenizer), + ComponentSpec("tokenizer_2", CLIPTokenizer), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config", + ), + ] + + @property + def expected_configs(self) -> list[ConfigSpec]: + return [ConfigSpec("force_zeros_for_empty_prompt", True)] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("prompt"), + InputParam("prompt_2"), + InputParam("negative_prompt"), + InputParam("negative_prompt_2"), + InputParam("cross_attention_kwargs"), + InputParam("clip_skip"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="text embeddings used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="negative text embeddings used to guide the image generation", + ), + OutputParam( + "pooled_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="pooled text embeddings used to guide the image generation", + ), + OutputParam( + "negative_pooled_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="negative pooled text embeddings used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + if block_state.prompt is not None and ( + not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list) + ): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") + elif block_state.prompt_2 is not None and ( + not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list) + ): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}") + + @staticmethod + def encode_prompt( + components, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prepare_unconditional_embeds: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prepare_unconditional_embeds (`bool`): + whether to use prepare unconditional embeddings or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or components._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin): + components._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if components.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(components.text_encoder, lora_scale) + else: + scale_lora_layers(components.text_encoder, lora_scale) + + if components.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale) + else: + scale_lora_layers(components.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = ( + [components.tokenizer, components.tokenizer_2] + if components.tokenizer is not None + else [components.tokenizer_2] + ) + text_encoders = ( + [components.text_encoder, components.text_encoder_2] + if components.text_encoder is not None + else [components.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(components, TextualInversionLoaderMixin): + prompt = components.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt + if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif prepare_unconditional_embeds and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(components, TextualInversionLoaderMixin): + negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if components.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if prepare_unconditional_embeds: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if components.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=components.text_encoder_2.dtype, device=device + ) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if prepare_unconditional_embeds: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if components.text_encoder is not None: + if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(components.text_encoder, lora_scale) + + if components.text_encoder_2 is not None: + if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(components.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + # Get inputs and intermediates + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 + block_state.device = components._execution_device + + # Encode input prompt + block_state.text_encoder_lora_scale = ( + block_state.cross_attention_kwargs.get("scale", None) + if block_state.cross_attention_kwargs is not None + else None + ) + ( + block_state.prompt_embeds, + block_state.negative_prompt_embeds, + block_state.pooled_prompt_embeds, + block_state.negative_pooled_prompt_embeds, + ) = self.encode_prompt( + components, + block_state.prompt, + block_state.prompt_2, + block_state.device, + 1, + block_state.prepare_unconditional_embeds, + block_state.negative_prompt, + block_state.negative_prompt_2, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + lora_scale=block_state.text_encoder_lora_scale, + clip_skip=block_state.clip_skip, + ) + # Add outputs + self.set_block_state(state, block_state) + return components, state + + +class StableDiffusionXLVaeEncoderStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def description(self) -> str: + return "Vae Encoder step that encode the input image into a latent representation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("image", required=True), + InputParam("height"), + InputParam("width"), + InputParam("generator"), + InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + InputParam( + "preprocess_kwargs", + type_hint=dict | None, + description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "image_latents", + type_hint=torch.Tensor, + description="The latents representing the reference image for image-to-image/inpainting generation", + ) + ] + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std + else: + image_latents = components.vae.config.scaling_factor * image_latents + + return image_latents + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.preprocess_kwargs = block_state.preprocess_kwargs or {} + block_state.device = components._execution_device + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + + image = components.image_processor.preprocess( + block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs + ) + image = image.to(device=block_state.device, dtype=block_state.dtype) + block_state.batch_size = image.shape[0] + + # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) + if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" + f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators." + ) + + block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator) + + self.set_block_state(state, block_state) + + return components, state + + +class StableDiffusionXLInpaintVaeEncoderStep(ModularPipelineBlocks): + model_name = "stable-diffusion-xl" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ComponentSpec( + "mask_processor", + VaeImageProcessor, + config=FrozenDict( + {"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True} + ), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Vae encoder step that prepares the image and mask for the inpainting process" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("height"), + InputParam("width"), + InputParam("image", required=True), + InputParam("mask_image", required=True), + InputParam("padding_mask_crop"), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "image_latents", type_hint=torch.Tensor, description="The latents representation of the input image" + ), + OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"), + OutputParam( + "masked_image_latents", + type_hint=torch.Tensor, + description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)", + ), + OutputParam( + "crops_coords", + type_hint=tuple[int, int] | None, + description="The crop coordinates to use for the preprocess/postprocess of the image and mask", + ), + ] + + # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components + # YiYi TODO: update the _encode_vae_image so that we can use #Coped from + def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator): + latents_mean = latents_std = None + if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None: + latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None: + latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1) + + dtype = image.dtype + if components.vae.config.force_upcast: + image = image.float() + components.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(components.vae.encode(image), generator=generator) + + if components.vae.config.force_upcast: + components.vae.to(dtype) + + image_latents = image_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype) + latents_std = latents_std.to(device=image_latents.device, dtype=dtype) + image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + image_latents = components.vae.config.scaling_factor * image_latents + + return image_latents + + # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents + # do not accept do_classifier_free_guidance + def prepare_mask_latents( + self, components, mask, masked_image, batch_size, height, width, dtype, device, generator + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + @torch.no_grad() + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype + block_state.device = components._execution_device + + if block_state.height is None: + block_state.height = components.default_height + if block_state.width is None: + block_state.width = components.default_width + + if block_state.padding_mask_crop is not None: + block_state.crops_coords = components.mask_processor.get_crop_region( + block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop + ) + block_state.resize_mode = "fill" + else: + block_state.crops_coords = None + block_state.resize_mode = "default" + + image = components.image_processor.preprocess( + block_state.image, + height=block_state.height, + width=block_state.width, + crops_coords=block_state.crops_coords, + resize_mode=block_state.resize_mode, + ) + image = image.to(dtype=torch.float32) + + mask = components.mask_processor.preprocess( + block_state.mask_image, + height=block_state.height, + width=block_state.width, + resize_mode=block_state.resize_mode, + crops_coords=block_state.crops_coords, + ) + block_state.masked_image = image * (mask < 0.5) + + block_state.batch_size = image.shape[0] + image = image.to(device=block_state.device, dtype=block_state.dtype) + block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator) + + # 7. Prepare mask latent variables + block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( + components, + mask, + block_state.masked_image, + block_state.batch_size, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + ) + + self.set_block_state(state, block_state) + + return components, state diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_stable_diffusion_xl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_stable_diffusion_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..a7a18e51477716619d43b1bd2933b86fed53bf88 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_stable_diffusion_xl.py @@ -0,0 +1,512 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import OutputParam +from .before_denoise import ( + StableDiffusionXLControlNetInputStep, + StableDiffusionXLControlNetUnionInputStep, + StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, + StableDiffusionXLImg2ImgPrepareLatentsStep, + StableDiffusionXLImg2ImgSetTimestepsStep, + StableDiffusionXLInpaintPrepareLatentsStep, + StableDiffusionXLInputStep, + StableDiffusionXLPrepareAdditionalConditioningStep, + StableDiffusionXLPrepareLatentsStep, + StableDiffusionXLSetTimestepsStep, +) +from .decoders import ( + StableDiffusionXLDecodeStep, + StableDiffusionXLInpaintOverlayMaskStep, +) +from .denoise import ( + StableDiffusionXLControlNetDenoiseStep, + StableDiffusionXLDenoiseStep, + StableDiffusionXLInpaintControlNetDenoiseStep, + StableDiffusionXLInpaintDenoiseStep, +) +from .encoders import ( + StableDiffusionXLInpaintVaeEncoderStep, + StableDiffusionXLIPAdapterStep, + StableDiffusionXLTextEncoderStep, + StableDiffusionXLVaeEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# auto blocks & sequential blocks & mappings + + +# vae encoder (run before before_denoise) +class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep] + block_names = ["inpaint", "img2img"] + block_trigger_inputs = ["mask_image", "image"] + + @property + def description(self): + return ( + "Vae encoder step that encode the image inputs into their latent representations.\n" + + "This is an auto pipeline block that works for both inpainting and img2img tasks.\n" + + " - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n" + + " - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided." + + " - if neither `mask_image` nor `image` is provided, step will be skipped." + ) + + +# optional ip-adapter (run before input step) +class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLIPAdapterStep] + block_names = ["ip_adapter"] + block_trigger_inputs = ["ip_adapter_image"] + + @property + def description(self): + return "Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n" + + +# before_denoise: text2img +class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + StableDiffusionXLSetTimestepsStep, + StableDiffusionXLPrepareLatentsStep, + StableDiffusionXLPrepareAdditionalConditioningStep, + ] + block_names = ["set_timesteps", "prepare_latents", "prepare_add_cond"] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs for the denoise step.\n" + + "This is a sequential pipeline blocks:\n" + + " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n" + + " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n" + + " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + ) + + +# before_denoise: img2img +class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + StableDiffusionXLImg2ImgSetTimestepsStep, + StableDiffusionXLImg2ImgPrepareLatentsStep, + StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, + ] + block_names = ["set_timesteps", "prepare_latents", "prepare_add_cond"] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs for the denoise step for img2img task.\n" + + "This is a sequential pipeline blocks:\n" + + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + + " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n" + + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + ) + + +# before_denoise: inpainting +class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + StableDiffusionXLImg2ImgSetTimestepsStep, + StableDiffusionXLInpaintPrepareLatentsStep, + StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep, + ] + block_names = ["set_timesteps", "prepare_latents", "prepare_add_cond"] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs for the denoise step for inpainting task.\n" + + "This is a sequential pipeline blocks:\n" + + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n" + + " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n" + + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n" + ) + + +# before_denoise: all task (text2img, img2img, inpainting) +class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks): + block_classes = [ + StableDiffusionXLInpaintBeforeDenoiseStep, + StableDiffusionXLImg2ImgBeforeDenoiseStep, + StableDiffusionXLBeforeDenoiseStep, + ] + block_names = ["inpaint", "img2img", "text2img"] + block_trigger_inputs = ["mask", "image_latents", None] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs for the denoise step.\n" + + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks as well as controlnet, controlnet_union.\n" + + " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n" + + " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n" + + " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided.\n" + ) + + +# optional controlnet input step (after before_denoise, before denoise) +# works for both controlnet and controlnet_union +class StableDiffusionXLAutoControlNetInputStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep] + block_names = ["controlnet_union", "controlnet"] + block_trigger_inputs = ["control_mode", "control_image"] + + @property + def description(self): + return ( + "Controlnet Input step that prepare the controlnet input.\n" + + "This is an auto pipeline block that works for both controlnet and controlnet_union.\n" + + " (it should be called right before the denoise step)" + + " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n" + + " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided." + + " - if neither `control_mode` nor `control_image` is provided, step will be skipped." + ) + + +# denoise: controlnet (text2img, img2img, inpainting) +class StableDiffusionXLAutoControlNetDenoiseStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintControlNetDenoiseStep, StableDiffusionXLControlNetDenoiseStep] + block_names = ["inpaint_controlnet_denoise", "controlnet_denoise"] + block_trigger_inputs = ["mask", "controlnet_cond"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents with controlnet. " + "This is a auto pipeline block that using controlnet for text2img, img2img and inpainting tasks." + "This block should not be used without a controlnet_cond input" + " - `StableDiffusionXLInpaintControlNetDenoiseStep` (inpaint_controlnet_denoise) is used when mask is provided." + " - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when mask is not provided but controlnet_cond is provided." + " - If neither mask nor controlnet_cond are provided, step will be skipped." + ) + + +# denoise: all task with or without controlnet (text2img, img2img, inpainting) +class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks): + block_classes = [ + StableDiffusionXLAutoControlNetDenoiseStep, + StableDiffusionXLInpaintDenoiseStep, + StableDiffusionXLDenoiseStep, + ] + block_names = ["controlnet_denoise", "inpaint_denoise", "denoise"] + block_trigger_inputs = ["controlnet_cond", "mask", None] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. " + "This is a auto pipeline block that works for text2img, img2img and inpainting tasks. And can be used with or without controlnet." + " - `StableDiffusionXLAutoControlNetDenoiseStep` (controlnet_denoise) is used when controlnet_cond is provided (support controlnet withtext2img, img2img and inpainting tasks)." + " - `StableDiffusionXLInpaintDenoiseStep` (inpaint_denoise) is used when mask is provided (support inpainting tasks)." + " - `StableDiffusionXLDenoiseStep` (denoise) is used when neither mask nor controlnet_cond are provided (support text2img and img2img tasks)." + ) + + +# decode: inpaint +class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks): + block_classes = [StableDiffusionXLDecodeStep, StableDiffusionXLInpaintOverlayMaskStep] + block_names = ["decode", "mask_overlay"] + + @property + def description(self): + return ( + "Inpaint decode step that decode the denoised latents into images outputs.\n" + + "This is a sequential pipeline blocks:\n" + + " - `StableDiffusionXLDecodeStep` is used to decode the denoised latents into images\n" + + " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image" + ) + + +# decode: all task (text2img, img2img, inpainting) +class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks): + block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep] + block_names = ["inpaint", "non-inpaint"] + block_trigger_inputs = ["padding_mask_crop", None] + + @property + def description(self): + return ( + "Decode step that decode the denoised latents into images outputs.\n" + + "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" + + " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + + " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided." + ) + + +class StableDiffusionXLCoreDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + StableDiffusionXLInputStep, + StableDiffusionXLAutoBeforeDenoiseStep, + StableDiffusionXLAutoControlNetInputStep, + StableDiffusionXLAutoDenoiseStep, + ] + block_names = ["input", "before_denoise", "controlnet_input", "denoise"] + + @property + def description(self): + return ( + "Core step that performs the denoising process. \n" + + " - `StableDiffusionXLInputStep` (input) standardizes the inputs for the denoising step.\n" + + " - `StableDiffusionXLAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n" + + " - `StableDiffusionXLAutoControlNetInputStep` (controlnet_input) prepares the controlnet input.\n" + + " - `StableDiffusionXLAutoDenoiseStep` (denoise) iteratively denoises the latents.\n\n" + + "This step support text-to-image, image-to-image, inpainting, with or without controlnet/controlnet_union/ip_adapter for Stable Diffusion XL:\n" + + "- for image-to-image generation, you need to provide `image_latents`\n" + + "- for inpainting, you need to provide `mask_image` and `image_latents`\n" + + "- to run the controlnet workflow, you need to provide `control_image`\n" + + "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + + "- to run the ip_adapter workflow, you need to load ip_adapter into your unet and provide `ip_adapter_embeds`\n" + + "- for text-to-image generation, all you need to provide is prompt embeddings\n" + ) + + +# ip-adapter, controlnet, text2img, img2img, inpainting +# auto_docstring +class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion + XL. + + Supported workflows: + - `text2image`: requires `prompt` + - `image2image`: requires `image`, `prompt` + - `inpainting`: requires `mask_image`, `image`, `prompt` + - `controlnet_text2image`: requires `control_image`, `prompt` + - `controlnet_image2image`: requires `control_image`, `image`, `prompt` + - `controlnet_inpainting`: requires `control_image`, `mask_image`, `image`, `prompt` + - `controlnet_union_text2image`: requires `control_image`, `control_mode`, `prompt` + - `controlnet_union_image2image`: requires `control_image`, `control_mode`, `image`, `prompt` + - `controlnet_union_inpainting`: requires `control_image`, `control_mode`, `mask_image`, `image`, `prompt` + - `ip_adapter_text2image`: requires `ip_adapter_image`, `prompt` + - `ip_adapter_image2image`: requires `ip_adapter_image`, `image`, `prompt` + - `ip_adapter_inpainting`: requires `ip_adapter_image`, `mask_image`, `image`, `prompt` + - `ip_adapter_controlnet_text2image`: requires `ip_adapter_image`, `control_image`, `prompt` + - `ip_adapter_controlnet_image2image`: requires `ip_adapter_image`, `control_image`, `image`, `prompt` + - `ip_adapter_controlnet_inpainting`: requires `ip_adapter_image`, `control_image`, `mask_image`, `image`, + `prompt` + - `ip_adapter_controlnet_union_text2image`: requires `ip_adapter_image`, `control_image`, `control_mode`, + `prompt` + - `ip_adapter_controlnet_union_image2image`: requires `ip_adapter_image`, `control_image`, `control_mode`, + `image`, `prompt` + - `ip_adapter_controlnet_union_inpainting`: requires `ip_adapter_image`, `control_image`, `control_mode`, + `mask_image`, `image`, `prompt` + + Components: + text_encoder (`CLIPTextModel`) text_encoder_2 (`CLIPTextModelWithProjection`) tokenizer (`CLIPTokenizer`) + tokenizer_2 (`CLIPTokenizer`) guider (`ClassifierFreeGuidance`) image_encoder + (`CLIPVisionModelWithProjection`) feature_extractor (`CLIPImageProcessor`) unet (`UNet2DConditionModel`) vae + (`AutoencoderKL`) image_processor (`VaeImageProcessor`) mask_processor (`VaeImageProcessor`) scheduler + (`EulerDiscreteScheduler`) controlnet (`ControlNetUnionModel`) control_image_processor (`VaeImageProcessor`) + + Configs: + force_zeros_for_empty_prompt (default: True) requires_aesthetics_score (default: False) + + Inputs: + prompt (`None`, *optional*): + TODO: Add description. + prompt_2 (`None`, *optional*): + TODO: Add description. + negative_prompt (`None`, *optional*): + TODO: Add description. + negative_prompt_2 (`None`, *optional*): + TODO: Add description. + cross_attention_kwargs (`None`, *optional*): + TODO: Add description. + clip_skip (`None`, *optional*): + TODO: Add description. + ip_adapter_image (`Image | ndarray | Tensor | list | list | list`, *optional*): + The image(s) to be used as ip adapter + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + image (`None`, *optional*): + TODO: Add description. + mask_image (`None`, *optional*): + TODO: Add description. + padding_mask_crop (`None`, *optional*): + TODO: Add description. + dtype (`dtype`, *optional*): + The dtype of the model inputs + generator (`None`, *optional*): + TODO: Add description. + preprocess_kwargs (`dict | NoneType`, *optional*): + A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under + `self.image_processor` in [diffusers.image_processor.VaeImageProcessor] + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + ip_adapter_embeds (`list`, *optional*): + Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step. + negative_ip_adapter_embeds (`list`, *optional*): + Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + denoising_end (`None`, *optional*): + TODO: Add description. + strength (`None`, *optional*, defaults to 0.3): + TODO: Add description. + denoising_start (`None`, *optional*): + TODO: Add description. + latents (`None`): + TODO: Add description. + image_latents (`Tensor`, *optional*): + The latents representing the reference image for image-to-image/inpainting generation. Can be generated + in vae_encode step. + mask (`Tensor`, *optional*): + The mask for the inpainting generation. Can be generated in vae_encode step. + masked_image_latents (`Tensor`, *optional*): + The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be + generated in vae_encode step. + original_size (`None`, *optional*): + TODO: Add description. + target_size (`None`, *optional*): + TODO: Add description. + negative_original_size (`None`, *optional*): + TODO: Add description. + negative_target_size (`None`, *optional*): + TODO: Add description. + crops_coords_top_left (`None`, *optional*, defaults to (0, 0)): + TODO: Add description. + negative_crops_coords_top_left (`None`, *optional*, defaults to (0, 0)): + TODO: Add description. + aesthetic_score (`None`, *optional*, defaults to 6.0): + TODO: Add description. + negative_aesthetic_score (`None`, *optional*, defaults to 2.0): + TODO: Add description. + control_image (`None`, *optional*): + TODO: Add description. + control_mode (`None`, *optional*): + TODO: Add description. + control_guidance_start (`None`, *optional*, defaults to 0.0): + TODO: Add description. + control_guidance_end (`None`, *optional*, defaults to 1.0): + TODO: Add description. + controlnet_conditioning_scale (`None`, *optional*, defaults to 1.0): + TODO: Add description. + guess_mode (`None`, *optional*, defaults to False): + TODO: Add description. + crops_coords (`tuple | NoneType`, *optional*): + The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can + be generated in vae_encode step. + controlnet_cond (`Tensor`, *optional*): + The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step. + conditioning_scale (`float`, *optional*): + The controlnet conditioning scale value to use for the denoising process. Can be generated in + prepare_controlnet_inputs step. + controlnet_keep (`list`, *optional*): + The controlnet keep values to use for the denoising process. Can be generated in + prepare_controlnet_inputs step. + **denoiser_input_fields (`None`, *optional*): + All conditional model inputs that need to be prepared with guider. It should contain + prompt_embeds/negative_prompt_embeds, add_time_ids/negative_add_time_ids, + pooled_prompt_embeds/negative_pooled_prompt_embeds, and ip_adapter_embeds/negative_ip_adapter_embeds + (optional).please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when + they are created and added to the pipeline state + eta (`None`, *optional*, defaults to 0.0): + TODO: Add description. + output_type (`None`, *optional*, defaults to pil): + TODO: Add description. + + Outputs: + images (`list`): + Generated images. + """ + + block_classes = [ + StableDiffusionXLTextEncoderStep, + StableDiffusionXLAutoIPAdapterStep, + StableDiffusionXLAutoVaeEncoderStep, + StableDiffusionXLCoreDenoiseStep, + StableDiffusionXLAutoDecodeStep, + ] + block_names = [ + "text_encoder", + "ip_adapter", + "vae_encoder", + "denoise", + "decode", + ] + + _workflow_map = { + "text2image": {"prompt": True}, + "image2image": {"image": True, "prompt": True}, + "inpainting": {"mask_image": True, "image": True, "prompt": True}, + "controlnet_text2image": {"control_image": True, "prompt": True}, + "controlnet_image2image": {"control_image": True, "image": True, "prompt": True}, + "controlnet_inpainting": {"control_image": True, "mask_image": True, "image": True, "prompt": True}, + "controlnet_union_text2image": {"control_image": True, "control_mode": True, "prompt": True}, + "controlnet_union_image2image": {"control_image": True, "control_mode": True, "image": True, "prompt": True}, + "controlnet_union_inpainting": { + "control_image": True, + "control_mode": True, + "mask_image": True, + "image": True, + "prompt": True, + }, + "ip_adapter_text2image": {"ip_adapter_image": True, "prompt": True}, + "ip_adapter_image2image": {"ip_adapter_image": True, "image": True, "prompt": True}, + "ip_adapter_inpainting": {"ip_adapter_image": True, "mask_image": True, "image": True, "prompt": True}, + "ip_adapter_controlnet_text2image": {"ip_adapter_image": True, "control_image": True, "prompt": True}, + "ip_adapter_controlnet_image2image": { + "ip_adapter_image": True, + "control_image": True, + "image": True, + "prompt": True, + }, + "ip_adapter_controlnet_inpainting": { + "ip_adapter_image": True, + "control_image": True, + "mask_image": True, + "image": True, + "prompt": True, + }, + "ip_adapter_controlnet_union_text2image": { + "ip_adapter_image": True, + "control_image": True, + "control_mode": True, + "prompt": True, + }, + "ip_adapter_controlnet_union_image2image": { + "ip_adapter_image": True, + "control_image": True, + "control_mode": True, + "image": True, + "prompt": True, + }, + "ip_adapter_controlnet_union_inpainting": { + "ip_adapter_image": True, + "control_image": True, + "control_mode": True, + "mask_image": True, + "image": True, + "prompt": True, + }, + } + + @property + def description(self): + return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL." + + @property + def outputs(self): + return [OutputParam.template("images")] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..209e2b11814f8e1242db1de79e79977add1ad7a5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py @@ -0,0 +1,360 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import PIL +import torch + +from ...image_processor import PipelineImageInput +from ...loaders import ModularIPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...pipelines.pipeline_utils import StableDiffusionMixin +from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from ...utils import logging +from ..modular_pipeline import ModularPipeline +from ..modular_pipeline_utils import InputParam, OutputParam + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder? +# YiYi Notes: model specific components: +## (1) it should inherit from ModularPipeline +## (2) acts like a container that holds components and configs +## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents +## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin) +## (5) how to use together with Components_manager? +class StableDiffusionXLModularPipeline( + ModularPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + ModularIPAdapterMixin, +): + """ + A ModularPipeline for Stable Diffusion XL. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "StableDiffusionXLAutoBlocks" + + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + default_sample_size = 128 + if hasattr(self, "unet") and self.unet is not None: + default_sample_size = self.unet.config.sample_size + return default_sample_size + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + # YiYi TODO: change to num_channels_latents + @property + def num_channels_unet(self): + num_channels_unet = 4 + if hasattr(self, "unet") and self.unet is not None: + num_channels_unet = self.unet.config.in_channels + return num_channels_unet + + @property + def num_channels_latents(self): + num_channels_latents = 4 + if hasattr(self, "vae") and self.vae is not None: + num_channels_latents = self.vae.config.latent_channels + return num_channels_latents + + +# YiYi/Sayak TODO: not used yet, maintain a list of schema that can be used across all pipeline blocks +# auto_docstring +SDXL_INPUTS_SCHEMA = { + "prompt": InputParam( + "prompt", type_hint=str | list[str], description="The prompt or prompts to guide the image generation" + ), + "prompt_2": InputParam( + "prompt_2", + type_hint=str | list[str], + description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2", + ), + "negative_prompt": InputParam( + "negative_prompt", + type_hint=str | list[str], + description="The prompt or prompts not to guide the image generation", + ), + "negative_prompt_2": InputParam( + "negative_prompt_2", + type_hint=str | list[str], + description="The negative prompt or prompts for text_encoder_2", + ), + "cross_attention_kwargs": InputParam( + "cross_attention_kwargs", + type_hint=dict | None, + description="Kwargs dictionary passed to the AttentionProcessor", + ), + "clip_skip": InputParam( + "clip_skip", type_hint=int | None, description="Number of layers to skip in CLIP text encoder" + ), + "image": InputParam( + "image", + type_hint=PipelineImageInput, + required=True, + description="The image(s) to modify for img2img or inpainting", + ), + "mask_image": InputParam( + "mask_image", + type_hint=PipelineImageInput, + required=True, + description="Mask image for inpainting, white pixels will be repainted", + ), + "generator": InputParam( + "generator", + type_hint=torch.Generator | list[torch.Generator] | None, + description="Generator(s) for deterministic generation", + ), + "height": InputParam("height", type_hint=int | None, description="Height in pixels of the generated image"), + "width": InputParam("width", type_hint=int | None, description="Width in pixels of the generated image"), + "num_images_per_prompt": InputParam( + "num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt" + ), + "num_inference_steps": InputParam( + "num_inference_steps", type_hint=int, default=50, description="Number of denoising steps" + ), + "timesteps": InputParam( + "timesteps", type_hint=torch.Tensor | None, description="Custom timesteps for the denoising process" + ), + "sigmas": InputParam( + "sigmas", type_hint=torch.Tensor | None, description="Custom sigmas for the denoising process" + ), + "denoising_end": InputParam( + "denoising_end", + type_hint=float | None, + description="Fraction of denoising process to complete before termination", + ), + # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999 + "strength": InputParam( + "strength", type_hint=float, default=0.3, description="How much to transform the reference image" + ), + "denoising_start": InputParam( + "denoising_start", type_hint=float | None, description="Starting point of the denoising process" + ), + "latents": InputParam( + "latents", type_hint=torch.Tensor | None, description="Pre-generated noisy latents for image generation" + ), + "padding_mask_crop": InputParam( + "padding_mask_crop", + type_hint=tuple[int, int] | None, + description="Size of margin in crop for image and mask", + ), + "original_size": InputParam( + "original_size", + type_hint=tuple[int, int] | None, + description="Original size of the image for SDXL's micro-conditioning", + ), + "target_size": InputParam( + "target_size", type_hint=tuple[int, int] | None, description="Target size for SDXL's micro-conditioning" + ), + "negative_original_size": InputParam( + "negative_original_size", + type_hint=tuple[int, int] | None, + description="Negative conditioning based on image resolution", + ), + "negative_target_size": InputParam( + "negative_target_size", + type_hint=tuple[int, int] | None, + description="Negative conditioning based on target resolution", + ), + "crops_coords_top_left": InputParam( + "crops_coords_top_left", + type_hint=tuple[int, int], + default=(0, 0), + description="Top-left coordinates for SDXL's micro-conditioning", + ), + "negative_crops_coords_top_left": InputParam( + "negative_crops_coords_top_left", + type_hint=tuple[int, int], + default=(0, 0), + description="Negative conditioning crop coordinates", + ), + "aesthetic_score": InputParam( + "aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image" + ), + "negative_aesthetic_score": InputParam( + "negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score" + ), + "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), + "output_type": InputParam( + "output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)" + ), + "ip_adapter_image": InputParam( + "ip_adapter_image", + type_hint=PipelineImageInput, + required=True, + description="Image(s) to be used as IP adapter", + ), + "control_image": InputParam( + "control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition" + ), + "control_guidance_start": InputParam( + "control_guidance_start", + type_hint=float | list[float], + default=0.0, + description="When ControlNet starts applying", + ), + "control_guidance_end": InputParam( + "control_guidance_end", + type_hint=float | list[float], + default=1.0, + description="When ControlNet stops applying", + ), + "controlnet_conditioning_scale": InputParam( + "controlnet_conditioning_scale", + type_hint=float | list[float], + default=1.0, + description="Scale factor for ControlNet outputs", + ), + "guess_mode": InputParam( + "guess_mode", + type_hint=bool, + default=False, + description="Enables ControlNet encoder to recognize input without prompts", + ), + "control_mode": InputParam( + "control_mode", type_hint=list[int], required=True, description="Control mode for union controlnet" + ), + "prompt_embeds": InputParam( + "prompt_embeds", + type_hint=torch.Tensor, + required=True, + description="Text embeddings used to guide image generation", + ), + "negative_prompt_embeds": InputParam( + "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings" + ), + "pooled_prompt_embeds": InputParam( + "pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings" + ), + "negative_pooled_prompt_embeds": InputParam( + "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings" + ), + "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"), + "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + "preprocess_kwargs": InputParam( + "preprocess_kwargs", type_hint=dict | None, description="Kwargs for ImageProcessor" + ), + "latent_timestep": InputParam( + "latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep" + ), + "image_latents": InputParam( + "image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image" + ), + "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"), + "masked_image_latents": InputParam( + "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting" + ), + "add_time_ids": InputParam( + "add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning" + ), + "negative_add_time_ids": InputParam( + "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids" + ), + "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), + "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), + "crops_coords": InputParam("crops_coords", type_hint=tuple[int] | None, description="Crop coordinates"), + "ip_adapter_embeds": InputParam( + "ip_adapter_embeds", type_hint=list[torch.Tensor], description="Image embeddings for IP-Adapter" + ), + "negative_ip_adapter_embeds": InputParam( + "negative_ip_adapter_embeds", + type_hint=list[torch.Tensor], + description="Negative image embeddings for IP-Adapter", + ), + "images": InputParam( + "images", + type_hint=list[PIL.Image.Image] | list[torch.Tensor] | list[np.array], + required=True, + description="Generated images", + ), +} + + +SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = { + "prompt_embeds": OutputParam( + "prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation" + ), + "negative_prompt_embeds": OutputParam( + "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings" + ), + "pooled_prompt_embeds": OutputParam( + "pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings" + ), + "negative_pooled_prompt_embeds": OutputParam( + "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings" + ), + "batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"), + "dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), + "image_latents": OutputParam( + "image_latents", type_hint=torch.Tensor, description="Latents representing reference image" + ), + "mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"), + "masked_image_latents": OutputParam( + "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting" + ), + "crops_coords": OutputParam("crops_coords", type_hint=tuple[int] | None, description="Crop coordinates"), + "timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"), + "num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"), + "latent_timestep": OutputParam( + "latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep" + ), + "add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"), + "negative_add_time_ids": OutputParam( + "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids" + ), + "timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), + "latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"), + "noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), + "ip_adapter_embeds": OutputParam( + "ip_adapter_embeds", type_hint=list[torch.Tensor], description="Image embeddings for IP-Adapter" + ), + "negative_ip_adapter_embeds": OutputParam( + "negative_ip_adapter_embeds", + type_hint=list[torch.Tensor], + description="Negative image embeddings for IP-Adapter", + ), + "images": OutputParam( + "images", + type_hint=list[PIL.Image.Image] | list[torch.Tensor] | list[np.array], + description="Generated images", + ), +} + + +SDXL_OUTPUTS_SCHEMA = { + "images": OutputParam( + "images", + type_hint=tuple[list[PIL.Image.Image] | list[torch.Tensor] | list[np.array]] | StableDiffusionXLPipelineOutput, + description="The final generated images", + ) +} diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..284b6c9fa436619460ce456dab54e6fd63f7906f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__init__.py @@ -0,0 +1,63 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_blocks_wan"] = ["WanBlocks"] + _import_structure["modular_blocks_wan22"] = ["Wan22Blocks"] + _import_structure["modular_blocks_wan22_i2v"] = ["Wan22Image2VideoBlocks"] + _import_structure["modular_blocks_wan_i2v"] = ["WanImage2VideoAutoBlocks"] + _import_structure["modular_pipeline"] = [ + "Wan22Image2VideoModularPipeline", + "Wan22ModularPipeline", + "WanImage2VideoModularPipeline", + "WanModularPipeline", + ] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_blocks_wan import WanBlocks + from .modular_blocks_wan22 import Wan22Blocks + from .modular_blocks_wan22_i2v import Wan22Image2VideoBlocks + from .modular_blocks_wan_i2v import WanImage2VideoAutoBlocks + from .modular_pipeline import ( + Wan22Image2VideoModularPipeline, + Wan22ModularPipeline, + WanImage2VideoModularPipeline, + WanModularPipeline, + ) +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1db431147493c05fcf9bafc3026aa69943b90677 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/before_denoise.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/before_denoise.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1bec48b5f04f439819c775f3faba1e8182f6f4f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/before_denoise.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/decoders.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/decoders.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e63423e22e9837254d30f2817a2e9ec60f98f63 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/decoders.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/denoise.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/denoise.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bfa248c290a834cbabed703a6a6d64afab21db8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/denoise.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/encoders.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/encoders.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..317206c829eba6e69eac637ce6c4ced774e91ed8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/encoders.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/modular_blocks_wan.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/modular_blocks_wan.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e34e0c1a690927bb6704ce5bbbb8286a8426394b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/modular_blocks_wan.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/modular_blocks_wan22.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/modular_blocks_wan22.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c43308923748dc7c6fcf35e3cbcdb6a490dcdc55 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/modular_blocks_wan22.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/modular_blocks_wan22_i2v.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/modular_blocks_wan22_i2v.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..971b500ff10ad22a8bb8f428a48270cd0a27684c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/modular_blocks_wan22_i2v.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/modular_blocks_wan_i2v.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/modular_blocks_wan_i2v.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..525bf4a5ff252d612e8b06b9be3adfcec4476537 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/modular_blocks_wan_i2v.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/modular_pipeline.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/modular_pipeline.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5a385edc5fdf2f55c77f70b8cd7322331488582 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/__pycache__/modular_pipeline.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/before_denoise.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/before_denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..398b9665522c8877817c88ab4aeb034d5c4f3464 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/before_denoise.py @@ -0,0 +1,554 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import torch + +from ...models import WanTransformer3DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import WanModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that +# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by +# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the +# configuration of guider is. + + +def repeat_tensor_to_batch_size( + input_name: str, + input_tensor: torch.Tensor, + batch_size: int, + num_videos_per_prompt: int = 1, +) -> torch.Tensor: + """Repeat tensor elements to match the final batch size. + + This function expands a tensor's batch dimension to match the final batch size (batch_size * num_videos_per_prompt) + by repeating each element along dimension 0. + + The input tensor must have batch size 1 or batch_size. The function will: + - If batch size is 1: repeat each element (batch_size * num_videos_per_prompt) times + - If batch size equals batch_size: repeat each element num_videos_per_prompt times + + Args: + input_name (str): Name of the input tensor (used for error messages) + input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size. + batch_size (int): The base batch size (number of prompts) + num_videos_per_prompt (int, optional): Number of videos to generate per prompt. Defaults to 1. + + Returns: + torch.Tensor: The repeated tensor with final batch size (batch_size * num_videos_per_prompt) + + Raises: + ValueError: If input_tensor is not a torch.Tensor or has invalid batch size + + Examples: + tensor = torch.tensor([[1, 2, 3]]) # shape: [1, 3] repeated = repeat_tensor_to_batch_size("image", tensor, + batch_size=2, num_videos_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape: + [4, 3] + + tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: [2, 3] repeated = repeat_tensor_to_batch_size("image", + tensor, batch_size=2, num_videos_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]) + - shape: [4, 3] + """ + # make sure input is a tensor + if not isinstance(input_tensor, torch.Tensor): + raise ValueError(f"`{input_name}` must be a tensor") + + # make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts + if input_tensor.shape[0] == 1: + repeat_by = batch_size * num_videos_per_prompt + elif input_tensor.shape[0] == batch_size: + repeat_by = num_videos_per_prompt + else: + raise ValueError( + f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}" + ) + + # expand the tensor to match the batch_size * num_videos_per_prompt + input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0) + + return input_tensor + + +def calculate_dimension_from_latents( + latents: torch.Tensor, vae_scale_factor_temporal: int, vae_scale_factor_spatial: int +) -> tuple[int, int]: + """Calculate image dimensions from latent tensor dimensions. + + This function converts latent temporal and spatial dimensions to image temporal and spatial dimensions by + multiplying the latent num_frames/height/width by the VAE scale factor. + + Args: + latents (torch.Tensor): The latent tensor. Must have 4 or 5 dimensions. + Expected shapes: [batch, channels, height, width] or [batch, channels, frames, height, width] + vae_scale_factor_temporal (int): The scale factor used by the VAE to compress temporal dimension. + Typically 4 for most VAEs (video is 4x larger than latents in temporal dimension) + vae_scale_factor_spatial (int): The scale factor used by the VAE to compress spatial dimension. + Typically 8 for most VAEs (image is 8x larger than latents in each dimension) + + Returns: + tuple[int, int]: The calculated image dimensions as (height, width) + + Raises: + ValueError: If latents tensor doesn't have 4 or 5 dimensions + + """ + if latents.ndim != 5: + raise ValueError(f"latents must have 5 dimensions, but got {latents.ndim}") + + _, _, num_latent_frames, latent_height, latent_width = latents.shape + + num_frames = (num_latent_frames - 1) * vae_scale_factor_temporal + 1 + height = latent_height * vae_scale_factor_spatial + width = latent_width * vae_scale_factor_spatial + + return num_frames, height, width + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class WanTextInputStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return ( + "Input processing step that:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_videos_per_prompt`\n\n" + "All input tensors are expected to have either batch_size=1 or match the batch_size\n" + "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" + "have a final batch_size of batch_size * num_videos_per_prompt." + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("transformer", WanTransformer3DModel), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_videos_per_prompt", default=1), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + description="Pre-generated negative text embeddings. Can be generated from text_encoder step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_videos_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `transformer.dtype`)", + ), + ] + + def check_inputs(self, components, block_state): + if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: + if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`" + f" {block_state.negative_prompt_embeds.shape}." + ) + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_videos_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1 + ) + + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat( + 1, block_state.num_videos_per_prompt, 1 + ) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view( + block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1 + ) + + self.set_block_state(state, block_state) + + return components, state + + +class WanAdditionalInputsStep(ModularPipelineBlocks): + model_name = "wan" + + def __init__( + self, + image_latent_inputs: list[str] = ["image_condition_latents"], + additional_batch_inputs: list[str] = [], + ): + """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n" + + This step handles multiple common tasks to prepare inputs for the denoising step: + 1. For encoded image latents, use it update height/width if None, and expands batch size + 2. For additional_batch_inputs: Only expands batch dimensions to match final batch size + + This is a dynamic block that allows you to configure which inputs to process. + + Args: + image_latent_inputs (list[str], optional): Names of image latent tensors to process. + In additional to adjust batch size of these inputs, they will be used to determine height/width. Can be + a single string or list of strings. Defaults to ["image_condition_latents"]. + additional_batch_inputs (List[str], optional): + Names of additional conditional input tensors to expand batch size. These tensors will only have their + batch dimensions adjusted to match the final batch size. Can be a single string or list of strings. + Defaults to []. + + Examples: + # Configure to process image_condition_latents (default behavior) WanAdditionalInputsStep() # Configure to + process image latents and additional batch inputs WanAdditionalInputsStep( + image_latent_inputs=["image_condition_latents"], additional_batch_inputs=["image_embeds"] + ) + """ + if not isinstance(image_latent_inputs, list): + image_latent_inputs = [image_latent_inputs] + if not isinstance(additional_batch_inputs, list): + additional_batch_inputs = [additional_batch_inputs] + + self._image_latent_inputs = image_latent_inputs + self._additional_batch_inputs = additional_batch_inputs + super().__init__() + + @property + def description(self) -> str: + # Functionality section + summary_section = ( + "Input processing step that:\n" + " 1. For image latent inputs: Updates height/width if None, and expands batch size\n" + " 2. For additional batch inputs: Expands batch dimensions to match final batch size" + ) + + # Inputs info + inputs_info = "" + if self._image_latent_inputs or self._additional_batch_inputs: + inputs_info = "\n\nConfigured inputs:" + if self._image_latent_inputs: + inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}" + if self._additional_batch_inputs: + inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}" + + # Placement guidance + placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." + + return summary_section + inputs_info + placement_section + + @property + def inputs(self) -> list[InputParam]: + inputs = [ + InputParam(name="num_videos_per_prompt", default=1), + InputParam(name="batch_size", required=True), + InputParam(name="height"), + InputParam(name="width"), + InputParam(name="num_frames"), + ] + + # Add image latent inputs + for image_latent_input_name in self._image_latent_inputs: + inputs.append(InputParam(name=image_latent_input_name)) + + # Add additional batch inputs + for input_name in self._additional_batch_inputs: + inputs.append(InputParam(name=input_name)) + + return inputs + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Process image latent inputs (height/width calculation, patchify, and batch expansion) + for image_latent_input_name in self._image_latent_inputs: + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue + + # 1. Calculate num_frames, height/width from latents + num_frames, height, width = calculate_dimension_from_latents( + image_latent_tensor, components.vae_scale_factor_temporal, components.vae_scale_factor_spatial + ) + block_state.num_frames = block_state.num_frames or num_frames + block_state.height = block_state.height or height + block_state.width = block_state.width or width + + # 3. Expand batch size + image_latent_tensor = repeat_tensor_to_batch_size( + input_name=image_latent_input_name, + input_tensor=image_latent_tensor, + num_videos_per_prompt=block_state.num_videos_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, image_latent_input_name, image_latent_tensor) + + # Process additional batch inputs (only batch expansion) + for input_name in self._additional_batch_inputs: + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue + + # Only expand batch size + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_videos_per_prompt=block_state.num_videos_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state + + +class WanSetTimestepsStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", UniPCMultistepScheduler), + ] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + ] + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + block_state.num_inference_steps, + device, + block_state.timesteps, + block_state.sigmas, + ) + + self.set_block_state(state, block_state) + return components, state + + +class WanPrepareLatentsStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "Prepare latents step that prepares the latents for the text-to-video generation process" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("num_frames", type_hint=int), + InputParam("latents", type_hint=torch.Tensor | None), + InputParam("num_videos_per_prompt", type_hint=int, default=1), + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_videos_per_prompt`. Can be generated in input step.", + ), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ) + ] + + @staticmethod + def check_inputs(components, block_state): + if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or ( + block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}." + ) + if block_state.num_frames is not None and ( + block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0 + ): + raise ValueError( + f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}." + ) + + @staticmethod + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents with self->comp + def prepare_latents( + comp, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // comp.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // comp.vae_scale_factor_spatial, + int(width) // comp.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + device = components._execution_device + dtype = torch.float32 # Wan latents should be torch.float32 for best quality + + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + block_state.num_frames = block_state.num_frames or components.default_num_frames + + block_state.latents = self.prepare_latents( + components, + batch_size=block_state.batch_size * block_state.num_videos_per_prompt, + num_channels_latents=components.num_channels_latents, + height=block_state.height, + width=block_state.width, + num_frames=block_state.num_frames, + dtype=dtype, + device=device, + generator=block_state.generator, + latents=block_state.latents, + ) + + self.set_block_state(state, block_state) + + return components, state diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/decoders.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..9d1a4cf4f3486245ed15c3a753738ae7d0608e4d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/decoders.py @@ -0,0 +1,98 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import numpy as np +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKLWan +from ...utils import logging +from ...video_processor import VideoProcessor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class WanVaeDecoderStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLWan), + ComponentSpec( + "video_processor", + VideoProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images" + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents from the denoising step", + ), + InputParam( + "output_type", default="np", type_hint=str, description="The output type of the decoded videos" + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "videos", + type_hint=list[list[PIL.Image.Image]] | list[torch.Tensor] | list[np.ndarray], + description="The generated videos, can be a PIL.Image.Image, torch.Tensor or a numpy array", + ) + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + vae_dtype = components.vae.dtype + + latents = block_state.latents + latents_mean = ( + torch.tensor(components.vae.config.latents_mean) + .view(1, components.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view( + 1, components.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + latents = latents.to(vae_dtype) + block_state.videos = components.vae.decode(latents, return_dict=False)[0] + + output_type = getattr(block_state, "output_type", "np") + block_state.videos = components.video_processor.postprocess_video(block_state.videos, output_type=output_type) + + self.set_block_state(state, block_state) + + return components, state diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/denoise.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..2f51f353012ea6452610d327087bff3db867a99c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/denoise.py @@ -0,0 +1,544 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import WanTransformer3DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam +from .modular_pipeline import WanModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class WanLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepares the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `WanDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of the model inputs. Can be generated in input step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + block_state.latent_model_input = block_state.latents.to(block_state.dtype) + return components, block_state + + +class WanImage2VideoLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepares the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `WanDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "image_condition_latents", + required=True, + type_hint=torch.Tensor, + description="The image condition latents to use for the denoising process. Can be generated in prepare_first_frame_latents/prepare_first_last_frame_latents step.", + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of the model inputs. Can be generated in input step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + block_state.latent_model_input = torch.cat( + [block_state.latents, block_state.image_condition_latents], dim=1 + ).to(block_state.dtype) + return components, block_state + + +class WanLoopDenoiser(ModularPipelineBlocks): + model_name = "wan" + + def __init__( + self, + guider_input_fields: dict[str, Any] = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")}, + ): + """Initialize a denoiser block that calls the denoiser model. This block is used in Wan2.1. + + Args: + guider_input_fields: A dictionary that maps each argument expected by the denoiser model + (for example, "encoder_hidden_states") to data stored on 'block_state'. The value can be either: + + - A tuple of strings. For instance, {"encoder_hidden_states": ("prompt_embeds", + "negative_prompt_embeds")} tells the guider to read `block_state.prompt_embeds` and + `block_state.negative_prompt_embeds` and pass them as the conditional and unconditional batches of + 'encoder_hidden_states'. + - A string. For example, {"encoder_hidden_image": "image_embeds"} makes the guider forward + `block_state.image_embeds` for both conditional and unconditional batches. + """ + if not isinstance(guider_input_fields, dict): + raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}") + self._guider_input_fields = guider_input_fields + super().__init__() + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 5.0}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", WanTransformer3DModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents with guidance. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `WanDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + inputs = [ + InputParam("attention_kwargs"), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + guider_input_names = [] + for value in self._guider_input_fields.values(): + if isinstance(value, tuple): + guider_input_names.extend(value) + else: + guider_input_names.append(value) + + for name in guider_input_names: + inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor)) + return inputs + + @torch.no_grad() + def __call__( + self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). + guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = { + k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v + for k, v in cond_kwargs.items() + if k in self._guider_input_fields.keys() + } + + # Predict the noise residual + # store the noise_pred in guider_state_batch so that we can apply guidance across all batches + guider_state_batch.noise_pred = components.transformer( + hidden_states=block_state.latent_model_input.to(block_state.dtype), + timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype), + attention_kwargs=block_state.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + components.guider.cleanup_models(components.transformer) + + # Perform guidance + block_state.noise_pred = components.guider(guider_state)[0] + + return components, block_state + + +class Wan22LoopDenoiser(ModularPipelineBlocks): + model_name = "wan" + + def __init__( + self, + guider_input_fields: dict[str, Any] = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")}, + ): + """Initialize a denoiser block that calls the denoiser model. This block is used in Wan2.2. + + Args: + guider_input_fields: A dictionary that maps each argument expected by the denoiser model + (for example, "encoder_hidden_states") to data stored on `block_state`. The value can be either: + + - A tuple of strings. For instance, `{"encoder_hidden_states": ("prompt_embeds", + "negative_prompt_embeds")}` tells the guider to read `block_state.prompt_embeds` and + `block_state.negative_prompt_embeds` and pass them as the conditional and unconditional batches of + `encoder_hidden_states`. + - A string. For example, `{"encoder_hidden_image": "image_embeds"}` makes the guider forward + `block_state.image_embeds` for both conditional and unconditional batches. + """ + if not isinstance(guider_input_fields, dict): + raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}") + self._guider_input_fields = guider_input_fields + super().__init__() + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ComponentSpec( + "guider_2", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 3.0}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", WanTransformer3DModel), + ComponentSpec("transformer_2", WanTransformer3DModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents with guidance. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `WanDenoiseLoopWrapper`)" + ) + + @property + def expected_configs(self) -> list[ConfigSpec]: + return [ + ConfigSpec( + name="boundary_ratio", + default=0.875, + description="The boundary ratio to divide the denoising loop into high noise and low noise stages.", + ), + ] + + @property + def inputs(self) -> list[tuple[str, Any]]: + inputs = [ + InputParam("attention_kwargs"), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + guider_input_names = [] + for value in self._guider_input_fields.values(): + if isinstance(value, tuple): + guider_input_names.extend(value) + else: + guider_input_names.append(value) + + for name in guider_input_names: + inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor)) + return inputs + + @torch.no_grad() + def __call__( + self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + boundary_timestep = components.config.boundary_ratio * components.num_train_timesteps + if t >= boundary_timestep: + block_state.current_model = components.transformer + block_state.guider = components.guider + else: + block_state.current_model = components.transformer_2 + block_state.guider = components.guider_2 + + block_state.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). + guider_state = block_state.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + block_state.guider.prepare_models(block_state.current_model) + cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = { + k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v + for k, v in cond_kwargs.items() + if k in self._guider_input_fields.keys() + } + + # Predict the noise residual + # store the noise_pred in guider_state_batch so that we can apply guidance across all batches + guider_state_batch.noise_pred = block_state.current_model( + hidden_states=block_state.latent_model_input.to(block_state.dtype), + timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype), + attention_kwargs=block_state.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + block_state.guider.cleanup_models(block_state.current_model) + + # Perform guidance + block_state.noise_pred = block_state.guider(guider_state)[0] + + return components, block_state + + +class WanLoopAfterDenoiser(ModularPipelineBlocks): + model_name = "wan" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", UniPCMultistepScheduler), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that update the latents. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `WanDenoiseLoopWrapper`)" + ) + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + # Perform scheduler step using the predicted output + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred.float(), + t, + block_state.latents.float(), + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class WanDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoise the latents over `timesteps`. " + "The specific steps with each iteration can be customized with `sub_blocks` attributes" + ) + + @property + def loop_expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", UniPCMultistepScheduler), + ] + + @property + def loop_inputs(self) -> list[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + self.set_block_state(state, block_state) + + return components, state + + +class WanDenoiseStep(WanDenoiseLoopWrapper): + block_classes = [ + WanLoopBeforeDenoiser, + WanLoopDenoiser( + guider_input_fields={ + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + } + ), + WanLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `WanLoopBeforeDenoiser`\n" + " - `WanLoopDenoiser`\n" + " - `WanLoopAfterDenoiser`\n" + "This block supports text-to-video tasks for wan2.1." + ) + + +class Wan22DenoiseStep(WanDenoiseLoopWrapper): + block_classes = [ + WanLoopBeforeDenoiser, + Wan22LoopDenoiser( + guider_input_fields={ + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + } + ), + WanLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `WanLoopBeforeDenoiser`\n" + " - `Wan22LoopDenoiser`\n" + " - `WanLoopAfterDenoiser`\n" + "This block supports text-to-video tasks for Wan2.2." + ) + + +class WanImage2VideoDenoiseStep(WanDenoiseLoopWrapper): + block_classes = [ + WanImage2VideoLoopBeforeDenoiser, + WanLoopDenoiser( + guider_input_fields={ + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + "encoder_hidden_states_image": "image_embeds", + } + ), + WanLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `WanImage2VideoLoopBeforeDenoiser`\n" + " - `WanLoopDenoiser`\n" + " - `WanLoopAfterDenoiser`\n" + "This block supports image-to-video tasks for wan2.1." + ) + + +class Wan22Image2VideoDenoiseStep(WanDenoiseLoopWrapper): + block_classes = [ + WanImage2VideoLoopBeforeDenoiser, + Wan22LoopDenoiser( + guider_input_fields={ + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + } + ), + WanLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `WanImage2VideoLoopBeforeDenoiser`\n" + " - `WanLoopDenoiser`\n" + " - `WanLoopAfterDenoiser`\n" + "This block supports image-to-video tasks for Wan2.2." + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/encoders.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..3e675a66e4f2b7acbdd09fcbf9a8fd61483c581e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/encoders.py @@ -0,0 +1,758 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html + +import numpy as np +import PIL +import regex as re +import torch +from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLWan +from ...utils import is_ftfy_available, is_torchvision_available, logging +from ...video_processor import VideoProcessor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import WanModularPipeline + + +if is_ftfy_available(): + import ftfy + +if is_torchvision_available(): + from torchvision import transforms + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +def get_t5_prompt_embeds( + text_encoder: UMT5EncoderModel, + tokenizer: AutoTokenizer, + prompt: str | list[str], + max_sequence_length: int, + device: torch.device, +): + dtype = text_encoder.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + return prompt_embeds + + +def encode_image( + image: PipelineImageInput, + image_processor: CLIPImageProcessor, + image_encoder: CLIPVisionModel, + device: torch.device | None = None, +): + image = image_processor(images=image, return_tensors="pt").to(device) + image_embeds = image_encoder(**image, output_hidden_states=True) + return image_embeds.hidden_states[-2] + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def encode_vae_image( + video_tensor: torch.Tensor, + vae: AutoencoderKLWan, + generator: torch.Generator, + device: torch.device, + dtype: torch.dtype, + latent_channels: int = 16, +): + if not isinstance(video_tensor, torch.Tensor): + raise ValueError(f"Expected video_tensor to be a tensor, got {type(video_tensor)}.") + + if isinstance(generator, list) and len(generator) != video_tensor.shape[0]: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but it is not same as number of images {video_tensor.shape[0]}." + ) + + video_tensor = video_tensor.to(device=device, dtype=dtype) + + if isinstance(generator, list): + video_latents = [ + retrieve_latents(vae.encode(video_tensor[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(video_tensor.shape[0]) + ] + video_latents = torch.cat(video_latents, dim=0) + else: + video_latents = retrieve_latents(vae.encode(video_tensor), sample_mode="argmax") + + latents_mean = ( + torch.tensor(vae.config.latents_mean) + .view(1, latent_channels, 1, 1, 1) + .to(video_latents.device, video_latents.dtype) + ) + latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, latent_channels, 1, 1, 1).to( + video_latents.device, video_latents.dtype + ) + video_latents = (video_latents - latents_mean) * latents_std + + return video_latents + + +class WanTextEncoderStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "Text Encoder step that generate text_embeddings to guide the video generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", UMT5EncoderModel), + ComponentSpec("tokenizer", AutoTokenizer), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 5.0}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("prompt"), + InputParam("negative_prompt"), + InputParam("max_sequence_length", default=512), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="text embeddings used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="negative text embeddings used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + if block_state.prompt is not None and ( + not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list) + ): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") + + @staticmethod + def encode_prompt( + components, + prompt: str, + device: torch.device | None = None, + prepare_unconditional_embeds: bool = True, + negative_prompt: str | None = None, + max_sequence_length: int = 512, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + prepare_unconditional_embeds (`bool`): + whether to use prepare unconditional embeddings or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + max_sequence_length (`int`, defaults to `512`): + The maximum number of text tokens to be used for the generation process. + """ + device = device or components._execution_device + if not isinstance(prompt, list): + prompt = [prompt] + batch_size = len(prompt) + + prompt_embeds = get_t5_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if prepare_unconditional_embeds: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = get_t5_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=negative_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + return prompt_embeds, negative_prompt_embeds + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + # Get inputs and intermediates + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.device = components._execution_device + + # Encode input prompt + ( + block_state.prompt_embeds, + block_state.negative_prompt_embeds, + ) = self.encode_prompt( + components=components, + prompt=block_state.prompt, + device=block_state.device, + prepare_unconditional_embeds=components.requires_unconditional_embeds, + negative_prompt=block_state.negative_prompt, + max_sequence_length=block_state.max_sequence_length, + ) + + # Add outputs + self.set_block_state(state, block_state) + return components, state + + +class WanImageResizeStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "Image Resize step that resize the image to the target area (height * width) while maintaining the aspect ratio." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("image", type_hint=PIL.Image.Image, required=True), + InputParam("height", type_hint=int, default=480), + InputParam("width", type_hint=int, default=832), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("resized_image", type_hint=PIL.Image.Image), + ] + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + max_area = block_state.height * block_state.width + + image = block_state.image + aspect_ratio = image.height / image.width + mod_value = components.vae_scale_factor_spatial * components.patch_size_spatial + block_state.height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + block_state.width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + block_state.resized_image = image.resize((block_state.width, block_state.height)) + + self.set_block_state(state, block_state) + return components, state + + +class WanImageCropResizeStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "Image Resize step that resize the last_image to the same size of first frame image with center crop." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "resized_image", type_hint=PIL.Image.Image, required=True, description="The resized first frame image" + ), + InputParam("last_image", type_hint=PIL.Image.Image, required=True, description="The last frameimage"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("resized_last_image", type_hint=PIL.Image.Image), + ] + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + height = block_state.resized_image.height + width = block_state.resized_image.width + image = block_state.last_image + + # Calculate resize ratio to match first frame dimensions + resize_ratio = max(width / image.width, height / image.height) + + # Resize the image + width = round(image.width * resize_ratio) + height = round(image.height * resize_ratio) + size = [width, height] + resized_image = transforms.functional.center_crop(image, size) + block_state.resized_last_image = resized_image + + self.set_block_state(state, block_state) + return components, state + + +class WanImageEncoderStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "Image Encoder step that generate image_embeds based on first frame image to guide the video generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("image_processor", CLIPImageProcessor), + ComponentSpec("image_encoder", CLIPVisionModel), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("resized_image", type_hint=PIL.Image.Image, required=True), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("image_embeds", type_hint=torch.Tensor, description="The image embeddings"), + ] + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + + image = block_state.resized_image + + image_embeds = encode_image( + image_processor=components.image_processor, + image_encoder=components.image_encoder, + image=image, + device=device, + ) + block_state.image_embeds = image_embeds + self.set_block_state(state, block_state) + return components, state + + +class WanFirstLastFrameImageEncoderStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "Image Encoder step that generate image_embeds based on first and last frame images to guide the video generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("image_processor", CLIPImageProcessor), + ComponentSpec("image_encoder", CLIPVisionModel), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("resized_image", type_hint=PIL.Image.Image, required=True), + InputParam("resized_last_image", type_hint=PIL.Image.Image, required=True), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("image_embeds", type_hint=torch.Tensor, description="The image embeddings"), + ] + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + + first_frame_image = block_state.resized_image + last_frame_image = block_state.resized_last_image + + image_embeds = encode_image( + image_processor=components.image_processor, + image_encoder=components.image_encoder, + image=[first_frame_image, last_frame_image], + device=device, + ) + block_state.image_embeds = image_embeds + self.set_block_state(state, block_state) + return components, state + + +class WanVaeEncoderStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "Vae Image Encoder step that generate condition_latents based on first frame image to guide the video generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLWan), + ComponentSpec( + "video_processor", + VideoProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("resized_image", type_hint=PIL.Image.Image, required=True), + InputParam("height"), + InputParam("width"), + InputParam("num_frames", type_hint=int, default=81), + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "first_frame_latents", + type_hint=torch.Tensor, + description="video latent representation with the first frame image condition", + ), + ] + + @staticmethod + def check_inputs(components, block_state): + if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or ( + block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}." + ) + if block_state.num_frames is not None and ( + block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0 + ): + raise ValueError( + f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}." + ) + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + image = block_state.resized_image + + device = components._execution_device + dtype = torch.float32 + vae_dtype = components.vae.dtype + + height = block_state.height or components.default_height + width = block_state.width or components.default_width + num_frames = block_state.num_frames or components.default_num_frames + + image_tensor = components.video_processor.preprocess(image, height=height, width=width).to( + device=device, dtype=dtype + ) + + if image_tensor.dim() == 4: + image_tensor = image_tensor.unsqueeze(2) + + video_tensor = torch.cat( + [ + image_tensor, + image_tensor.new_zeros(image_tensor.shape[0], image_tensor.shape[1], num_frames - 1, height, width), + ], + dim=2, + ).to(device=device, dtype=dtype) + + block_state.first_frame_latents = encode_vae_image( + video_tensor=video_tensor, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=vae_dtype, + latent_channels=components.num_channels_latents, + ) + + self.set_block_state(state, block_state) + return components, state + + +class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "step that prepares the masked first frame latents and add it to the latent condition" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("first_frame_latents", type_hint=torch.Tensor | None), + InputParam("num_frames", required=True), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("image_condition_latents", type_hint=torch.Tensor | None), + ] + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape + + mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width) + mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0 + + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave( + first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal + ) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view( + batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width + ) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device) + block_state.image_condition_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1) + + self.set_block_state(state, block_state) + return components, state + + +class WanFirstLastFrameVaeEncoderStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "Vae Image Encoder step that generate condition_latents based on first and last frame images to guide the video generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLWan), + ComponentSpec( + "video_processor", + VideoProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("resized_image", type_hint=PIL.Image.Image, required=True), + InputParam("resized_last_image", type_hint=PIL.Image.Image, required=True), + InputParam("height"), + InputParam("width"), + InputParam("num_frames", type_hint=int, default=81), + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "first_last_frame_latents", + type_hint=torch.Tensor, + description="video latent representation with the first and last frame images condition", + ), + ] + + @staticmethod + def check_inputs(components, block_state): + if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or ( + block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}." + ) + if block_state.num_frames is not None and ( + block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0 + ): + raise ValueError( + f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}." + ) + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + first_frame_image = block_state.resized_image + last_frame_image = block_state.resized_last_image + + device = components._execution_device + dtype = torch.float32 + vae_dtype = components.vae.dtype + + height = block_state.height or components.default_height + width = block_state.width or components.default_width + num_frames = block_state.num_frames or components.default_num_frames + + first_image_tensor = components.video_processor.preprocess(first_frame_image, height=height, width=width).to( + device=device, dtype=dtype + ) + first_image_tensor = first_image_tensor.unsqueeze(2) + + last_image_tensor = components.video_processor.preprocess(last_frame_image, height=height, width=width).to( + device=device, dtype=dtype + ) + + last_image_tensor = last_image_tensor.unsqueeze(2) + + video_tensor = torch.cat( + [ + first_image_tensor, + first_image_tensor.new_zeros( + first_image_tensor.shape[0], first_image_tensor.shape[1], num_frames - 2, height, width + ), + last_image_tensor, + ], + dim=2, + ).to(device=device, dtype=dtype) + + block_state.first_last_frame_latents = encode_vae_image( + video_tensor=video_tensor, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=vae_dtype, + latent_channels=components.num_channels_latents, + ) + + self.set_block_state(state, block_state) + return components, state + + +class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "step that prepares the masked latents with first and last frames and add it to the latent condition" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("first_last_frame_latents", type_hint=torch.Tensor | None), + InputParam("num_frames", type_hint=int, required=True), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("image_condition_latents", type_hint=torch.Tensor | None), + ] + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape + + mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width) + mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0 + + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave( + first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal + ) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view( + batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width + ) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device) + block_state.image_condition_latents = torch.concat( + [mask_lat_size, block_state.first_last_frame_latents], dim=1 + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/modular_blocks_wan.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/modular_blocks_wan.py new file mode 100644 index 0000000000000000000000000000000000000000..b641c6cd7fcc481f50e91631bb747fad8f2b3e13 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/modular_blocks_wan.py @@ -0,0 +1,162 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks +from ..modular_pipeline_utils import OutputParam +from .before_denoise import ( + WanPrepareLatentsStep, + WanSetTimestepsStep, + WanTextInputStep, +) +from .decoders import WanVaeDecoderStep +from .denoise import ( + WanDenoiseStep, +) +from .encoders import ( + WanTextEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# ==================== +# 1. DENOISE +# ==================== + + +# inputs(text) -> set_timesteps -> prepare_latents -> denoise +# auto_docstring +class WanCoreDenoiseStep(SequentialPipelineBlocks): + """ + denoise block that takes encoded conditions and runs the denoising process. + + Components: + transformer (`WanTransformer3DModel`) scheduler (`UniPCMultistepScheduler`) guider (`ClassifierFreeGuidance`) + + Inputs: + num_videos_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. Can be generated from text_encoder step. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + num_frames (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + attention_kwargs (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "wan" + block_classes = [ + WanTextInputStep, + WanSetTimestepsStep, + WanPrepareLatentsStep, + WanDenoiseStep, + ] + block_names = ["input", "set_timesteps", "prepare_latents", "denoise"] + + @property + def description(self): + return "denoise block that takes encoded conditions and runs the denoising process." + + @property + def outputs(self): + return [OutputParam.template("latents")] + + +# ==================== +# 2. BLOCKS (Wan2.1 text2video) +# ==================== + + +# auto_docstring +class WanBlocks(SequentialPipelineBlocks): + """ + Modular pipeline blocks for Wan2.1. + + Components: + text_encoder (`UMT5EncoderModel`) tokenizer (`AutoTokenizer`) guider (`ClassifierFreeGuidance`) transformer + (`WanTransformer3DModel`) scheduler (`UniPCMultistepScheduler`) vae (`AutoencoderKLWan`) video_processor + (`VideoProcessor`) + + Inputs: + prompt (`None`, *optional*): + TODO: Add description. + negative_prompt (`None`, *optional*): + TODO: Add description. + max_sequence_length (`None`, *optional*, defaults to 512): + TODO: Add description. + num_videos_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + num_frames (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + attention_kwargs (`None`, *optional*): + TODO: Add description. + output_type (`str`, *optional*, defaults to np): + The output type of the decoded videos + + Outputs: + videos (`list`): + The generated videos. + """ + + model_name = "wan" + block_classes = [ + WanTextEncoderStep, + WanCoreDenoiseStep, + WanVaeDecoderStep, + ] + block_names = ["text_encoder", "denoise", "decode"] + + @property + def description(self): + return "Modular pipeline blocks for Wan2.1." + + @property + def outputs(self): + return [OutputParam.template("videos")] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/modular_blocks_wan22.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/modular_blocks_wan22.py new file mode 100644 index 0000000000000000000000000000000000000000..9f602c24713b73a30ae4950b254b0e94f9cae750 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/modular_blocks_wan22.py @@ -0,0 +1,176 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks +from ..modular_pipeline_utils import OutputParam +from .before_denoise import ( + WanPrepareLatentsStep, + WanSetTimestepsStep, + WanTextInputStep, +) +from .decoders import WanVaeDecoderStep +from .denoise import ( + Wan22DenoiseStep, +) +from .encoders import ( + WanTextEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# ==================== +# 1. DENOISE +# ==================== + +# inputs(text) -> set_timesteps -> prepare_latents -> denoise + + +# auto_docstring +class Wan22CoreDenoiseStep(SequentialPipelineBlocks): + """ + denoise block that takes encoded conditions and runs the denoising process. + + Components: + transformer (`WanTransformer3DModel`) scheduler (`UniPCMultistepScheduler`) guider (`ClassifierFreeGuidance`) + guider_2 (`ClassifierFreeGuidance`) transformer_2 (`WanTransformer3DModel`) + + Configs: + boundary_ratio (default: 0.875): The boundary ratio to divide the denoising loop into high noise and low + noise stages. + + Inputs: + num_videos_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. Can be generated from text_encoder step. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + num_frames (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + attention_kwargs (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "wan" + block_classes = [ + WanTextInputStep, + WanSetTimestepsStep, + WanPrepareLatentsStep, + Wan22DenoiseStep, + ] + block_names = ["input", "set_timesteps", "prepare_latents", "denoise"] + + @property + def description(self): + return "denoise block that takes encoded conditions and runs the denoising process." + + @property + def outputs(self): + return [OutputParam.template("latents")] + + +# ==================== +# 2. BLOCKS (Wan2.2 text2video) +# ==================== + + +# auto_docstring +class Wan22Blocks(SequentialPipelineBlocks): + """ + Modular pipeline for text-to-video using Wan2.2. + + Components: + text_encoder (`UMT5EncoderModel`) tokenizer (`AutoTokenizer`) guider (`ClassifierFreeGuidance`) transformer + (`WanTransformer3DModel`) scheduler (`UniPCMultistepScheduler`) guider_2 (`ClassifierFreeGuidance`) + transformer_2 (`WanTransformer3DModel`) vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`) + + Configs: + boundary_ratio (default: 0.875): The boundary ratio to divide the denoising loop into high noise and low + noise stages. + + Inputs: + prompt (`None`, *optional*): + TODO: Add description. + negative_prompt (`None`, *optional*): + TODO: Add description. + max_sequence_length (`None`, *optional*, defaults to 512): + TODO: Add description. + num_videos_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + num_frames (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + attention_kwargs (`None`, *optional*): + TODO: Add description. + output_type (`str`, *optional*, defaults to np): + The output type of the decoded videos + + Outputs: + videos (`list`): + The generated videos. + """ + + model_name = "wan" + block_classes = [ + WanTextEncoderStep, + Wan22CoreDenoiseStep, + WanVaeDecoderStep, + ] + block_names = [ + "text_encoder", + "denoise", + "decode", + ] + + @property + def description(self): + return "Modular pipeline for text-to-video using Wan2.2." + + @property + def outputs(self): + return [OutputParam.template("videos")] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/modular_blocks_wan22_i2v.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/modular_blocks_wan22_i2v.py new file mode 100644 index 0000000000000000000000000000000000000000..8e55b7a50f08ae1f6e78f0084dac38abaa676e5a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/modular_blocks_wan22_i2v.py @@ -0,0 +1,236 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks +from ..modular_pipeline_utils import OutputParam +from .before_denoise import ( + WanAdditionalInputsStep, + WanPrepareLatentsStep, + WanSetTimestepsStep, + WanTextInputStep, +) +from .decoders import WanVaeDecoderStep +from .denoise import ( + Wan22Image2VideoDenoiseStep, +) +from .encoders import ( + WanImageResizeStep, + WanPrepareFirstFrameLatentsStep, + WanTextEncoderStep, + WanVaeEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# ==================== +# 1. VAE ENCODER +# ==================== + + +# auto_docstring +class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks): + """ + Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent + representation + + Components: + vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`) + + Inputs: + image (`Image`): + TODO: Add description. + height (`int`, *optional*, defaults to 480): + TODO: Add description. + width (`int`, *optional*, defaults to 832): + TODO: Add description. + num_frames (`int`, *optional*, defaults to 81): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + resized_image (`Image`): + TODO: Add description. + first_frame_latents (`Tensor`): + video latent representation with the first frame image condition + image_condition_latents (`Tensor | NoneType`): + TODO: Add description. + """ + + model_name = "wan-i2v" + block_classes = [WanImageResizeStep, WanVaeEncoderStep, WanPrepareFirstFrameLatentsStep] + block_names = ["image_resize", "vae_encoder", "prepare_first_frame_latents"] + + @property + def description(self): + return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation" + + +# ==================== +# 2. DENOISE +# ==================== + + +# inputs (text + image_condition_latents) -> set_timesteps -> prepare_latents -> denoise (latents) +# auto_docstring +class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks): + """ + denoise block that takes encoded text and image latent conditions and runs the denoising process. + + Components: + transformer (`WanTransformer3DModel`) scheduler (`UniPCMultistepScheduler`) guider (`ClassifierFreeGuidance`) + guider_2 (`ClassifierFreeGuidance`) transformer_2 (`WanTransformer3DModel`) + + Configs: + boundary_ratio (default: 0.875): The boundary ratio to divide the denoising loop into high noise and low + noise stages. + + Inputs: + num_videos_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. Can be generated from text_encoder step. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + num_frames (`None`, *optional*): + TODO: Add description. + image_condition_latents (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + attention_kwargs (`None`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "wan-i2v" + block_classes = [ + WanTextInputStep, + WanAdditionalInputsStep(image_latent_inputs=["image_condition_latents"]), + WanSetTimestepsStep, + WanPrepareLatentsStep, + Wan22Image2VideoDenoiseStep, + ] + block_names = [ + "input", + "additional_inputs", + "set_timesteps", + "prepare_latents", + "denoise", + ] + + @property + def description(self): + return "denoise block that takes encoded text and image latent conditions and runs the denoising process." + + @property + def outputs(self): + return [OutputParam.template("latents")] + + +# ==================== +# 3. BLOCKS (Wan2.2 Image2Video) +# ==================== + + +# auto_docstring +class Wan22Image2VideoBlocks(SequentialPipelineBlocks): + """ + Modular pipeline for image-to-video using Wan2.2. + + Components: + text_encoder (`UMT5EncoderModel`) tokenizer (`AutoTokenizer`) guider (`ClassifierFreeGuidance`) vae + (`AutoencoderKLWan`) video_processor (`VideoProcessor`) transformer (`WanTransformer3DModel`) scheduler + (`UniPCMultistepScheduler`) guider_2 (`ClassifierFreeGuidance`) transformer_2 (`WanTransformer3DModel`) + + Configs: + boundary_ratio (default: 0.875): The boundary ratio to divide the denoising loop into high noise and low + noise stages. + + Inputs: + prompt (`None`, *optional*): + TODO: Add description. + negative_prompt (`None`, *optional*): + TODO: Add description. + max_sequence_length (`None`, *optional*, defaults to 512): + TODO: Add description. + image (`Image`): + TODO: Add description. + height (`int`, *optional*, defaults to 480): + TODO: Add description. + width (`int`, *optional*, defaults to 832): + TODO: Add description. + num_frames (`int`, *optional*, defaults to 81): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_videos_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + attention_kwargs (`None`, *optional*): + TODO: Add description. + output_type (`str`, *optional*, defaults to np): + The output type of the decoded videos + + Outputs: + videos (`list`): + The generated videos. + """ + + model_name = "wan-i2v" + block_classes = [ + WanTextEncoderStep, + WanImage2VideoVaeEncoderStep, + Wan22Image2VideoCoreDenoiseStep, + WanVaeDecoderStep, + ] + block_names = [ + "text_encoder", + "vae_encoder", + "denoise", + "decode", + ] + + @property + def description(self): + return "Modular pipeline for image-to-video using Wan2.2." + + @property + def outputs(self): + return [OutputParam.template("videos")] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/modular_blocks_wan_i2v.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/modular_blocks_wan_i2v.py new file mode 100644 index 0000000000000000000000000000000000000000..c08db62c469a67298946d751326020d45c1ef646 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/modular_blocks_wan_i2v.py @@ -0,0 +1,481 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import OutputParam +from .before_denoise import ( + WanAdditionalInputsStep, + WanPrepareLatentsStep, + WanSetTimestepsStep, + WanTextInputStep, +) +from .decoders import WanVaeDecoderStep +from .denoise import ( + WanImage2VideoDenoiseStep, +) +from .encoders import ( + WanFirstLastFrameImageEncoderStep, + WanFirstLastFrameVaeEncoderStep, + WanImageCropResizeStep, + WanImageEncoderStep, + WanImageResizeStep, + WanPrepareFirstFrameLatentsStep, + WanPrepareFirstLastFrameLatentsStep, + WanTextEncoderStep, + WanVaeEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# ==================== +# 1. IMAGE ENCODER +# ==================== + + +# wan2.1 I2V (first frame only) +# auto_docstring +class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks): + """ + Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings + + Components: + image_processor (`CLIPImageProcessor`) image_encoder (`CLIPVisionModel`) + + Inputs: + image (`Image`): + TODO: Add description. + height (`int`, *optional*, defaults to 480): + TODO: Add description. + width (`int`, *optional*, defaults to 832): + TODO: Add description. + + Outputs: + resized_image (`Image`): + TODO: Add description. + image_embeds (`Tensor`): + The image embeddings + """ + + model_name = "wan-i2v" + block_classes = [WanImageResizeStep, WanImageEncoderStep] + block_names = ["image_resize", "image_encoder"] + + @property + def description(self): + return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings" + + +# wan2.1 FLF2V (first and last frame) +# auto_docstring +class WanFLF2VImageEncoderStep(SequentialPipelineBlocks): + """ + FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image + embeddings + + Components: + image_processor (`CLIPImageProcessor`) image_encoder (`CLIPVisionModel`) + + Inputs: + image (`Image`): + TODO: Add description. + height (`int`, *optional*, defaults to 480): + TODO: Add description. + width (`int`, *optional*, defaults to 832): + TODO: Add description. + last_image (`Image`): + The last frameimage + + Outputs: + resized_image (`Image`): + TODO: Add description. + resized_last_image (`Image`): + TODO: Add description. + image_embeds (`Tensor`): + The image embeddings + """ + + model_name = "wan-i2v" + block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep] + block_names = ["image_resize", "last_image_resize", "image_encoder"] + + @property + def description(self): + return "FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image embeddings" + + +# wan2.1 Auto Image Encoder +# auto_docstring +class WanAutoImageEncoderStep(AutoPipelineBlocks): + """ + Image Encoder step that encode the image to generate the image embeddingsThis is an auto pipeline block that works + for image2video tasks. - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided. - + `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided. - if `last_image` or `image` is + not provided, step will be skipped. + + Components: + image_processor (`CLIPImageProcessor`) image_encoder (`CLIPVisionModel`) + + Inputs: + image (`Image`, *optional*): + TODO: Add description. + height (`int`, *optional*, defaults to 480): + TODO: Add description. + width (`int`, *optional*, defaults to 832): + TODO: Add description. + last_image (`Image`, *optional*): + The last frameimage + + Outputs: + resized_image (`Image`): + TODO: Add description. + resized_last_image (`Image`): + TODO: Add description. + image_embeds (`Tensor`): + The image embeddings + """ + + block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep] + block_names = ["flf2v_image_encoder", "image2video_image_encoder"] + block_trigger_inputs = ["last_image", "image"] + model_name = "wan-i2v" + + @property + def description(self): + return ( + "Image Encoder step that encode the image to generate the image embeddings" + + "This is an auto pipeline block that works for image2video tasks." + + " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided." + + " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided." + + " - if `last_image` or `image` is not provided, step will be skipped." + ) + + +# ==================== +# 2. VAE ENCODER +# ==================== + + +# wan2.1 I2V (first frame only) +# auto_docstring +class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks): + """ + Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent + representation + + Components: + vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`) + + Inputs: + image (`Image`): + TODO: Add description. + height (`int`, *optional*, defaults to 480): + TODO: Add description. + width (`int`, *optional*, defaults to 832): + TODO: Add description. + num_frames (`int`, *optional*, defaults to 81): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + resized_image (`Image`): + TODO: Add description. + first_frame_latents (`Tensor`): + video latent representation with the first frame image condition + image_condition_latents (`Tensor | NoneType`): + TODO: Add description. + """ + + model_name = "wan-i2v" + block_classes = [WanImageResizeStep, WanVaeEncoderStep, WanPrepareFirstFrameLatentsStep] + block_names = ["image_resize", "vae_encoder", "prepare_first_frame_latents"] + + @property + def description(self): + return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation" + + +# wan2.1 FLF2V (first and last frame) +# auto_docstring +class WanFLF2VVaeEncoderStep(SequentialPipelineBlocks): + """ + FLF2V Vae Image Encoder step that resize and encode and encode the first and last frame images to generate the + latent conditions + + Components: + vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`) + + Inputs: + image (`Image`): + TODO: Add description. + height (`int`, *optional*, defaults to 480): + TODO: Add description. + width (`int`, *optional*, defaults to 832): + TODO: Add description. + last_image (`Image`): + The last frameimage + num_frames (`int`, *optional*, defaults to 81): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + resized_image (`Image`): + TODO: Add description. + resized_last_image (`Image`): + TODO: Add description. + first_last_frame_latents (`Tensor`): + video latent representation with the first and last frame images condition + image_condition_latents (`Tensor | NoneType`): + TODO: Add description. + """ + + model_name = "wan-i2v" + block_classes = [ + WanImageResizeStep, + WanImageCropResizeStep, + WanFirstLastFrameVaeEncoderStep, + WanPrepareFirstLastFrameLatentsStep, + ] + block_names = ["image_resize", "last_image_resize", "vae_encoder", "prepare_first_last_frame_latents"] + + @property + def description(self): + return "FLF2V Vae Image Encoder step that resize and encode and encode the first and last frame images to generate the latent conditions" + + +# wan2.1 Auto Vae Encoder +# auto_docstring +class WanAutoVaeEncoderStep(AutoPipelineBlocks): + """ + Vae Image Encoder step that encode the image to generate the image latentsThis is an auto pipeline block that works + for image2video tasks. - `WanFLF2VVaeEncoderStep` (flf2v) is used when `last_image` is provided. - + `WanImage2VideoVaeEncoderStep` (image2video) is used when `image` is provided. - if `last_image` or `image` is not + provided, step will be skipped. + + Components: + vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`) + + Inputs: + image (`Image`, *optional*): + TODO: Add description. + height (`int`, *optional*, defaults to 480): + TODO: Add description. + width (`int`, *optional*, defaults to 832): + TODO: Add description. + last_image (`Image`, *optional*): + The last frameimage + num_frames (`int`, *optional*, defaults to 81): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + resized_image (`Image`): + TODO: Add description. + resized_last_image (`Image`): + TODO: Add description. + first_last_frame_latents (`Tensor`): + video latent representation with the first and last frame images condition + image_condition_latents (`Tensor | NoneType`): + TODO: Add description. + first_frame_latents (`Tensor`): + video latent representation with the first frame image condition + """ + + model_name = "wan-i2v" + block_classes = [WanFLF2VVaeEncoderStep, WanImage2VideoVaeEncoderStep] + block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"] + block_trigger_inputs = ["last_image", "image"] + + @property + def description(self): + return ( + "Vae Image Encoder step that encode the image to generate the image latents" + + "This is an auto pipeline block that works for image2video tasks." + + " - `WanFLF2VVaeEncoderStep` (flf2v) is used when `last_image` is provided." + + " - `WanImage2VideoVaeEncoderStep` (image2video) is used when `image` is provided." + + " - if `last_image` or `image` is not provided, step will be skipped." + ) + + +# ==================== +# 3. DENOISE (inputs -> set_timesteps -> prepare_latents -> denoise) +# ==================== + + +# wan2.1 I2V core denoise (support both I2V and FLF2V) +# inputs (text + image_condition_latents) -> set_timesteps -> prepare_latents -> denoise (latents) +# auto_docstring +class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks): + """ + denoise block that takes encoded text and image latent conditions and runs the denoising process. + + Components: + transformer (`WanTransformer3DModel`) scheduler (`UniPCMultistepScheduler`) guider (`ClassifierFreeGuidance`) + + Inputs: + num_videos_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. Can be generated from text_encoder step. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + num_frames (`None`, *optional*): + TODO: Add description. + image_condition_latents (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + attention_kwargs (`None`, *optional*): + TODO: Add description. + image_embeds (`Tensor`): + TODO: Add description. + + Outputs: + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be batch_size * num_videos_per_prompt + dtype (`dtype`): + Data type of model tensor inputs (determined by `transformer.dtype`) + latents (`Tensor`): + The initial latents to use for the denoising process + """ + + model_name = "wan-i2v" + block_classes = [ + WanTextInputStep, + WanAdditionalInputsStep(image_latent_inputs=["image_condition_latents"]), + WanSetTimestepsStep, + WanPrepareLatentsStep, + WanImage2VideoDenoiseStep, + ] + block_names = [ + "input", + "additional_inputs", + "set_timesteps", + "prepare_latents", + "denoise", + ] + + @property + def description(self): + return "denoise block that takes encoded text and image latent conditions and runs the denoising process." + + +# ==================== +# 4. BLOCKS (Wan2.1 Image2Video) +# ==================== + + +# wan2.1 Image2Video Auto Blocks +# auto_docstring +class WanImage2VideoAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for image-to-video using Wan. + + Supported workflows: + - `image2video`: requires `image`, `prompt` + - `flf2v`: requires `last_image`, `image`, `prompt` + + Components: + text_encoder (`UMT5EncoderModel`) tokenizer (`AutoTokenizer`) guider (`ClassifierFreeGuidance`) + image_processor (`CLIPImageProcessor`) image_encoder (`CLIPVisionModel`) vae (`AutoencoderKLWan`) + video_processor (`VideoProcessor`) transformer (`WanTransformer3DModel`) scheduler + (`UniPCMultistepScheduler`) + + Inputs: + prompt (`None`, *optional*): + TODO: Add description. + negative_prompt (`None`, *optional*): + TODO: Add description. + max_sequence_length (`None`, *optional*, defaults to 512): + TODO: Add description. + image (`Image`, *optional*): + TODO: Add description. + height (`int`, *optional*, defaults to 480): + TODO: Add description. + width (`int`, *optional*, defaults to 832): + TODO: Add description. + last_image (`Image`, *optional*): + The last frameimage + num_frames (`int`, *optional*, defaults to 81): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_videos_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + image_condition_latents (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + timesteps (`None`, *optional*): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + attention_kwargs (`None`, *optional*): + TODO: Add description. + image_embeds (`Tensor`): + TODO: Add description. + output_type (`str`, *optional*, defaults to np): + The output type of the decoded videos + + Outputs: + videos (`list`): + The generated videos. + """ + + model_name = "wan-i2v" + block_classes = [ + WanTextEncoderStep, + WanAutoImageEncoderStep, + WanAutoVaeEncoderStep, + WanImage2VideoCoreDenoiseStep, + WanVaeDecoderStep, + ] + block_names = [ + "text_encoder", + "image_encoder", + "vae_encoder", + "denoise", + "decode", + ] + + _workflow_map = { + "image2video": {"image": True, "prompt": True}, + "flf2v": {"last_image": True, "image": True, "prompt": True}, + } + + @property + def description(self): + return "Auto Modular pipeline for image-to-video using Wan." + + @property + def outputs(self): + return [OutputParam.template("videos")] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/modular_pipeline.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/modular_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..0e52026a51bf9462752c566fb99df0b8de0e3814 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/wan/modular_pipeline.py @@ -0,0 +1,141 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...loaders import WanLoraLoaderMixin +from ...pipelines.pipeline_utils import StableDiffusionMixin +from ...utils import logging +from ..modular_pipeline import ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class WanModularPipeline( + ModularPipeline, + StableDiffusionMixin, + WanLoraLoaderMixin, +): + """ + A ModularPipeline for Wan2.1 text2video. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "WanBlocks" + + @property + def default_height(self): + return self.default_sample_height * self.vae_scale_factor_spatial + + @property + def default_width(self): + return self.default_sample_width * self.vae_scale_factor_spatial + + @property + def default_num_frames(self): + return (self.default_sample_num_frames - 1) * self.vae_scale_factor_temporal + 1 + + @property + def default_sample_height(self): + return 60 + + @property + def default_sample_width(self): + return 104 + + @property + def default_sample_num_frames(self): + return 21 + + @property + def patch_size_spatial(self): + patch_size_spatial = 2 + if hasattr(self, "transformer") and self.transformer is not None: + patch_size_spatial = self.transformer.config.patch_size[1] + return patch_size_spatial + + @property + def vae_scale_factor_spatial(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** len(self.vae.temperal_downsample) + return vae_scale_factor + + @property + def vae_scale_factor_temporal(self): + vae_scale_factor = 4 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** sum(self.vae.temperal_downsample) + return vae_scale_factor + + @property + def num_channels_transformer(self): + num_channels_transformer = 16 + if hasattr(self, "transformer") and self.transformer is not None: + num_channels_transformer = self.transformer.config.in_channels + return num_channels_transformer + + @property + def num_channels_latents(self): + num_channels_latents = 16 + if hasattr(self, "vae") and self.vae is not None: + num_channels_latents = self.vae.config.z_dim + return num_channels_latents + + @property + def requires_unconditional_embeds(self): + requires_unconditional_embeds = False + + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 + + return requires_unconditional_embeds + + @property + def num_train_timesteps(self): + num_train_timesteps = 1000 + if hasattr(self, "scheduler") and self.scheduler is not None: + num_train_timesteps = self.scheduler.config.num_train_timesteps + return num_train_timesteps + + +class WanImage2VideoModularPipeline(WanModularPipeline): + """ + A ModularPipeline for Wan2.1 image2video (both I2V and FLF2V). + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "WanImage2VideoAutoBlocks" + + +class Wan22ModularPipeline(WanModularPipeline): + """ + A ModularPipeline for Wan2.2 text2video. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "Wan22Blocks" + + +class Wan22Image2VideoModularPipeline(Wan22ModularPipeline): + """ + A ModularPipeline for Wan2.2 image2video. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "Wan22Image2VideoBlocks" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5c04008d33052ede093286a8a800b01b960d0e27 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_blocks_z_image"] = ["ZImageAutoBlocks"] + _import_structure["modular_pipeline"] = ["ZImageModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_blocks_z_image import ZImageAutoBlocks + from .modular_pipeline import ZImageModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2747fa0acfc05b09c3525439e6783c052aee62a1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/__pycache__/before_denoise.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/__pycache__/before_denoise.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01251732ad4c4f0e9b8467f381621094fffa24df Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/__pycache__/before_denoise.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/__pycache__/decoders.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/__pycache__/decoders.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24b9b97d4b36831db6991f9f4b30537d0f8dee42 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/__pycache__/decoders.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/__pycache__/denoise.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/__pycache__/denoise.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da98bfcd6fa7141239e37f9b793f7fb19781f48f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/__pycache__/denoise.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/__pycache__/encoders.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/__pycache__/encoders.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30f2b28a0de2d5582cbb529e7993828aaea1ff9b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/__pycache__/encoders.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/__pycache__/modular_blocks_z_image.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/__pycache__/modular_blocks_z_image.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cade97b066437f086dc4ae309e5384310caa309 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/__pycache__/modular_blocks_z_image.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/__pycache__/modular_pipeline.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/__pycache__/modular_pipeline.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..823d99984382804651f77c0c02a441a8d3b3f206 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/__pycache__/modular_pipeline.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/before_denoise.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/before_denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..8558f2c67f655c907979eba60520a99f2b139a0f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/before_denoise.py @@ -0,0 +1,620 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import torch + +from ...models import ZImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import ZImageModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that +# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by +# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the +# configuration of guider is. + + +def repeat_tensor_to_batch_size( + input_name: str, + input_tensor: torch.Tensor, + batch_size: int, + num_images_per_prompt: int = 1, +) -> torch.Tensor: + """Repeat tensor elements to match the final batch size. + + This function expands a tensor's batch dimension to match the final batch size (batch_size * num_images_per_prompt) + by repeating each element along dimension 0. + + The input tensor must have batch size 1 or batch_size. The function will: + - If batch size is 1: repeat each element (batch_size * num_images_per_prompt) times + - If batch size equals batch_size: repeat each element num_images_per_prompt times + + Args: + input_name (str): Name of the input tensor (used for error messages) + input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size. + batch_size (int): The base batch size (number of prompts) + num_images_per_prompt (int, optional): Number of images to generate per prompt. Defaults to 1. + + Returns: + torch.Tensor: The repeated tensor with final batch size (batch_size * num_images_per_prompt) + + Raises: + ValueError: If input_tensor is not a torch.Tensor or has invalid batch size + + Examples: + tensor = torch.tensor([[1, 2, 3]]) # shape: [1, 3] repeated = repeat_tensor_to_batch_size("image", tensor, + batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape: + [4, 3] + + tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: [2, 3] repeated = repeat_tensor_to_batch_size("image", + tensor, batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]) + - shape: [4, 3] + """ + # make sure input is a tensor + if not isinstance(input_tensor, torch.Tensor): + raise ValueError(f"`{input_name}` must be a tensor") + + # make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts + if input_tensor.shape[0] == 1: + repeat_by = batch_size * num_images_per_prompt + elif input_tensor.shape[0] == batch_size: + repeat_by = num_images_per_prompt + else: + raise ValueError( + f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}" + ) + + # expand the tensor to match the batch_size * num_images_per_prompt + input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0) + + return input_tensor + + +def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor_spatial: int) -> tuple[int, int]: + """Calculate image dimensions from latent tensor dimensions. + + This function converts latent spatial dimensions to image spatial dimensions by multiplying the latent height/width + by the VAE scale factor. + + Args: + latents (torch.Tensor): The latent tensor. Must have 4 dimensions. + Expected shapes: [batch, channels, height, width] + vae_scale_factor (int): The scale factor used by the VAE to compress image spatial dimension. + By default, it is 16 + Returns: + tuple[int, int]: The calculated image dimensions as (height, width) + """ + latent_height, latent_width = latents.shape[2:] + height = latent_height * vae_scale_factor_spatial // 2 + width = latent_width * vae_scale_factor_spatial // 2 + + return height, width + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImageTextInputStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return ( + "Input processing step that:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" + "All input tensors are expected to have either batch_size=1 or match the batch_size\n" + "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" + "have a final batch_size of batch_size * num_images_per_prompt." + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("transformer", ZImageTransformer2DModel), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + InputParam( + "prompt_embeds", + required=True, + type_hint=list[torch.Tensor], + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "negative_prompt_embeds", + type_hint=list[torch.Tensor], + description="Pre-generated negative text embeddings. Can be generated from text_encoder step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `transformer.dtype`)", + ), + ] + + def check_inputs(self, components, block_state): + if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: + if not isinstance(block_state.prompt_embeds, list): + raise ValueError( + f"`prompt_embeds` must be a list when passed directly, but got {type(block_state.prompt_embeds)}." + ) + if not isinstance(block_state.negative_prompt_embeds, list): + raise ValueError( + f"`negative_prompt_embeds` must be a list when passed directly, but got {type(block_state.negative_prompt_embeds)}." + ) + if len(block_state.prompt_embeds) != len(block_state.negative_prompt_embeds): + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same length when passed directly, but" + f" got: `prompt_embeds` {len(block_state.prompt_embeds)} != `negative_prompt_embeds`" + f" {len(block_state.negative_prompt_embeds)}." + ) + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + block_state.batch_size = len(block_state.prompt_embeds) + block_state.dtype = block_state.prompt_embeds[0].dtype + + if block_state.num_images_per_prompt > 1: + prompt_embeds = [pe for pe in block_state.prompt_embeds for _ in range(block_state.num_images_per_prompt)] + block_state.prompt_embeds = prompt_embeds + + if block_state.negative_prompt_embeds is not None: + negative_prompt_embeds = [ + npe for npe in block_state.negative_prompt_embeds for _ in range(block_state.num_images_per_prompt) + ] + block_state.negative_prompt_embeds = negative_prompt_embeds + + self.set_block_state(state, block_state) + + return components, state + + +class ZImageAdditionalInputsStep(ModularPipelineBlocks): + model_name = "z-image" + + def __init__( + self, + image_latent_inputs: list[str] = ["image_latents"], + additional_batch_inputs: list[str] = [], + ): + """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n" + + This step handles multiple common tasks to prepare inputs for the denoising step: + 1. For encoded image latents, use it update height/width if None, and expands batch size + 2. For additional_batch_inputs: Only expands batch dimensions to match final batch size + + This is a dynamic block that allows you to configure which inputs to process. + + Args: + image_latent_inputs (list[str], optional): Names of image latent tensors to process. + In additional to adjust batch size of these inputs, they will be used to determine height/width. Can be + a single string or list of strings. Defaults to ["image_latents"]. + additional_batch_inputs (list[str], optional): + Names of additional conditional input tensors to expand batch size. These tensors will only have their + batch dimensions adjusted to match the final batch size. Can be a single string or list of strings. + Defaults to []. + + Examples: + # Configure to process image_latents (default behavior) ZImageAdditionalInputsStep() + + # Configure to process multiple image latent inputs + ZImageAdditionalInputsStep(image_latent_inputs=["image_latents", "control_image_latents"]) + + # Configure to process image latents and additional batch inputs ZImageAdditionalInputsStep( + image_latent_inputs=["image_latents"], additional_batch_inputs=["image_embeds"] + ) + """ + if not isinstance(image_latent_inputs, list): + image_latent_inputs = [image_latent_inputs] + if not isinstance(additional_batch_inputs, list): + additional_batch_inputs = [additional_batch_inputs] + + self._image_latent_inputs = image_latent_inputs + self._additional_batch_inputs = additional_batch_inputs + super().__init__() + + @property + def description(self) -> str: + # Functionality section + summary_section = ( + "Input processing step that:\n" + " 1. For image latent inputs: Updates height/width if None, and expands batch size\n" + " 2. For additional batch inputs: Expands batch dimensions to match final batch size" + ) + + # Inputs info + inputs_info = "" + if self._image_latent_inputs or self._additional_batch_inputs: + inputs_info = "\n\nConfigured inputs:" + if self._image_latent_inputs: + inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}" + if self._additional_batch_inputs: + inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}" + + # Placement guidance + placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." + + return summary_section + inputs_info + placement_section + + @property + def inputs(self) -> list[InputParam]: + inputs = [ + InputParam(name="num_images_per_prompt", default=1), + InputParam(name="batch_size", required=True), + InputParam(name="height"), + InputParam(name="width"), + ] + + # Add image latent inputs + for image_latent_input_name in self._image_latent_inputs: + inputs.append(InputParam(name=image_latent_input_name)) + + # Add additional batch inputs + for input_name in self._additional_batch_inputs: + inputs.append(InputParam(name=input_name)) + + return inputs + + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Process image latent inputs (height/width calculation, patchify, and batch expansion) + for image_latent_input_name in self._image_latent_inputs: + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue + + # 1. Calculate num_frames, height/width from latents + height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor_spatial) + block_state.height = block_state.height or height + block_state.width = block_state.width or width + + # Process additional batch inputs (only batch expansion) + for input_name in self._additional_batch_inputs: + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue + + # Only expand batch size + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state + + +class ZImagePrepareLatentsStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return "Prepare latents step that prepares the latents for the text-to-video generation process" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("latents", type_hint=torch.Tensor | None), + InputParam("num_images_per_prompt", type_hint=int, default=1), + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.", + ), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ) + ] + + def check_inputs(self, components, block_state): + if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or ( + block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}." + ) + + @staticmethod + # Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline.prepare_latents with self->comp + def prepare_latents( + comp, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (comp.vae_scale_factor * 2)) + width = 2 * (int(width) // (comp.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + device = components._execution_device + dtype = torch.float32 + + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + + block_state.latents = self.prepare_latents( + components, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_channels_latents=components.num_channels_latents, + height=block_state.height, + width=block_state.width, + dtype=dtype, + device=device, + generator=block_state.generator, + latents=block_state.latents, + ) + + self.set_block_state(state, block_state) + + return components, state + + +class ZImageSetTimestepsStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference. Need to run after prepare latents step." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("latents", required=True), + InputParam("num_inference_steps", default=9), + InputParam("sigmas"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process" + ), + ] + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + + latent_height, latent_width = block_state.latents.shape[2], block_state.latents.shape[3] + image_seq_len = (latent_height // 2) * (latent_width // 2) # sequence length after patchify + + mu = calculate_shift( + image_seq_len, + base_seq_len=components.scheduler.config.get("base_image_seq_len", 256), + max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096), + base_shift=components.scheduler.config.get("base_shift", 0.5), + max_shift=components.scheduler.config.get("max_shift", 1.15), + ) + components.scheduler.sigma_min = 0.0 + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + block_state.num_inference_steps, + device, + sigmas=block_state.sigmas, + mu=mu, + ) + + self.set_block_state(state, block_state) + return components, state + + +class ZImageSetTimestepsWithStrengthStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference with strength. Need to run after set timesteps step." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("timesteps", required=True), + InputParam("num_inference_steps", required=True), + InputParam("strength", default=0.6), + ] + + def check_inputs(self, components, block_state): + if block_state.strength < 0.0 or block_state.strength > 1.0: + raise ValueError(f"Strength must be between 0.0 and 1.0, but got {block_state.strength}") + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + init_timestep = min(block_state.num_inference_steps * block_state.strength, block_state.num_inference_steps) + + t_start = int(max(block_state.num_inference_steps - init_timestep, 0)) + timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start * components.scheduler.order) + + block_state.timesteps = timesteps + block_state.num_inference_steps = block_state.num_inference_steps - t_start + + self.set_block_state(state, block_state) + return components, state + + +class ZImagePrepareLatentswithImageStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return "step that prepares the latents with image condition, need to run after set timesteps and prepare latents step." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("latents", required=True), + InputParam("image_latents", required=True), + InputParam("timesteps", required=True), + ] + + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0]) + block_state.latents = components.scheduler.scale_noise( + block_state.image_latents, latent_timestep, block_state.latents + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/decoders.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..353253102376c7a0973ae1c457a314ad0cbe8ff5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/decoders.py @@ -0,0 +1,91 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import numpy as np +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class ZImageVaeDecoderStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8 * 2}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images" + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam( + "latents", + required=True, + ), + InputParam( + name="output_type", + default="pil", + type_hint=str, + description="The type of the output images, can be 'pil', 'np', 'pt'", + ), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return [ + OutputParam( + "images", + type_hint=list[PIL.Image.Image, list[torch.Tensor], list[np.ndarray]], + description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array", + ) + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + vae_dtype = components.vae.dtype + + latents = block_state.latents.to(vae_dtype) + latents = latents / components.vae.config.scaling_factor + components.vae.config.shift_factor + + block_state.images = components.vae.decode(latents, return_dict=False)[0] + block_state.images = components.image_processor.postprocess( + block_state.images, output_type=block_state.output_type + ) + + self.set_block_state(state, block_state) + + return components, state diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/denoise.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..863df312389a56794663a36082dce143e60bf80c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/denoise.py @@ -0,0 +1,314 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import ZImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam +from .modular_pipeline import ZImageModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class ZImageLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepares the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `ZImageDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of the model inputs. Can be generated in input step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + latents = block_state.latents.unsqueeze(2).to( + block_state.dtype + ) # [batch_size, num_channels, 1, height, width] + block_state.latent_model_input = list(latents.unbind(dim=0)) # list of [num_channels, 1, height, width] + + timestep = t.expand(latents.shape[0]).to(block_state.dtype) + timestep = (1000 - timestep) / 1000 + block_state.timestep = timestep + return components, block_state + + +class ZImageLoopDenoiser(ModularPipelineBlocks): + model_name = "z-image" + + def __init__( + self, + guider_input_fields: dict[str, Any] = {"cap_feats": ("prompt_embeds", "negative_prompt_embeds")}, + ): + """Initialize a denoiser block that calls the denoiser model. This block is used in Z-Image. + + Args: + guider_input_fields: A dictionary that maps each argument expected by the denoiser model + (for example, "encoder_hidden_states") to data stored on 'block_state'. The value can be either: + + - A tuple of strings. For instance, {"encoder_hidden_states": ("prompt_embeds", + "negative_prompt_embeds")} tells the guider to read `block_state.prompt_embeds` and + `block_state.negative_prompt_embeds` and pass them as the conditional and unconditional batches of + 'encoder_hidden_states'. + - A string. For example, {"encoder_hidden_image": "image_embeds"} makes the guider forward + `block_state.image_embeds` for both conditional and unconditional batches. + """ + if not isinstance(guider_input_fields, dict): + raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}") + self._guider_input_fields = guider_input_fields + super().__init__() + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 5.0, "enabled": False}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", ZImageTransformer2DModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents with guidance. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `ZImageDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + inputs = [ + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + kwargs_type="denoiser_input_fields", + description="The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.", + ), + ] + guider_input_names = [] + uncond_guider_input_names = [] + for value in self._guider_input_fields.values(): + if isinstance(value, tuple): + guider_input_names.append(value[0]) + uncond_guider_input_names.append(value[1]) + else: + guider_input_names.append(value) + + for name in guider_input_names: + inputs.append(InputParam(name=name, required=True)) + for name in uncond_guider_input_names: + inputs.append(InputParam(name=name)) + return inputs + + @torch.no_grad() + def __call__( + self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). + guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = guider_state_batch.as_dict() + + def _convert_dtype(v, dtype): + if isinstance(v, torch.Tensor): + return v.to(dtype) + elif isinstance(v, list): + return [_convert_dtype(t, dtype) for t in v] + return v + + cond_kwargs = { + k: _convert_dtype(v, block_state.dtype) + for k, v in cond_kwargs.items() + if k in self._guider_input_fields.keys() + } + + # Predict the noise residual + # store the noise_pred in guider_state_batch so that we can apply guidance across all batches + model_out_list = components.transformer( + x=block_state.latent_model_input, + t=block_state.timestep, + return_dict=False, + **cond_kwargs, + )[0] + noise_pred = torch.stack(model_out_list, dim=0).squeeze(2) + guider_state_batch.noise_pred = -noise_pred + components.guider.cleanup_models(components.transformer) + + # Perform guidance + block_state.noise_pred = components.guider(guider_state)[0] + + return components, block_state + + +class ZImageLoopAfterDenoiser(ModularPipelineBlocks): + model_name = "z-image" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that update the latents. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `ZImageDenoiseLoopWrapper`)" + ) + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + # Perform scheduler step using the predicted output + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred.float(), + t, + block_state.latents.float(), + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class ZImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoise the latents over `timesteps`. " + "The specific steps with each iteration can be customized with `sub_blocks` attributes" + ) + + @property + def loop_expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def loop_inputs(self) -> list[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + self.set_block_state(state, block_state) + + return components, state + + +class ZImageDenoiseStep(ZImageDenoiseLoopWrapper): + block_classes = [ + ZImageLoopBeforeDenoiser, + ZImageLoopDenoiser( + guider_input_fields={ + "cap_feats": ("prompt_embeds", "negative_prompt_embeds"), + } + ), + ZImageLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `ZImageDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `ZImageLoopBeforeDenoiser`\n" + " - `ZImageLoopDenoiser`\n" + " - `ZImageLoopAfterDenoiser`\n" + "This block supports text-to-image and image-to-image tasks for Z-Image." + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/encoders.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..06deb823689302c99cad58643350c3f184deb28f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/encoders.py @@ -0,0 +1,343 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import PIL +import torch +from transformers import Qwen2Tokenizer, Qwen3Model + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL +from ...utils import is_ftfy_available, logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import ZImageModularPipeline + + +if is_ftfy_available(): + pass + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_qwen_prompt_embeds( + text_encoder: Qwen3Model, + tokenizer: Qwen2Tokenizer, + prompt: str | list[str], + device: torch.device, + max_sequence_length: int = 512, +) -> list[torch.Tensor]: + prompt = [prompt] if isinstance(prompt, str) else prompt + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + prompt_embeds_list = [] + + for i in range(len(prompt_embeds)): + prompt_embeds_list.append(prompt_embeds[i][prompt_masks[i]]) + + return prompt_embeds_list + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def encode_vae_image( + image_tensor: torch.Tensor, + vae: AutoencoderKL, + generator: torch.Generator, + device: torch.device, + dtype: torch.dtype, + latent_channels: int = 16, +): + if not isinstance(image_tensor, torch.Tensor): + raise ValueError(f"Expected image_tensor to be a tensor, got {type(image_tensor)}.") + + if isinstance(generator, list) and len(generator) != image_tensor.shape[0]: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but it is not same as number of images {image_tensor.shape[0]}." + ) + + image_tensor = image_tensor.to(device=device, dtype=dtype) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(vae.encode(image_tensor[i : i + 1]), generator=generator[i]) + for i in range(image_tensor.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(vae.encode(image_tensor), generator=generator) + + image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor + + return image_latents + + +class ZImageTextEncoderStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return "Text Encoder step that generate text_embeddings to guide the video generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen3Model), + ComponentSpec("tokenizer", Qwen2Tokenizer), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 5.0, "enabled": False}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("prompt"), + InputParam("negative_prompt"), + InputParam("max_sequence_length", default=512), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + type_hint=list[torch.Tensor], + kwargs_type="denoiser_input_fields", + description="text embeddings used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=list[torch.Tensor], + kwargs_type="denoiser_input_fields", + description="negative text embeddings used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + if block_state.prompt is not None and ( + not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list) + ): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") + + @staticmethod + def encode_prompt( + components, + prompt: str, + device: torch.device | None = None, + prepare_unconditional_embeds: bool = True, + negative_prompt: str | None = None, + max_sequence_length: int = 512, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + prepare_unconditional_embeds (`bool`): + whether to use prepare unconditional embeddings or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + max_sequence_length (`int`, defaults to `512`): + The maximum number of text tokens to be used for the generation process. + """ + device = device or components._execution_device + if not isinstance(prompt, list): + prompt = [prompt] + batch_size = len(prompt) + + prompt_embeds = get_qwen_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_prompt_embeds = None + if prepare_unconditional_embeds: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = get_qwen_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=negative_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + return prompt_embeds, negative_prompt_embeds + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + # Get inputs and intermediates + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.device = components._execution_device + + # Encode input prompt + ( + block_state.prompt_embeds, + block_state.negative_prompt_embeds, + ) = self.encode_prompt( + components=components, + prompt=block_state.prompt, + device=block_state.device, + prepare_unconditional_embeds=components.requires_unconditional_embeds, + negative_prompt=block_state.negative_prompt, + max_sequence_length=block_state.max_sequence_length, + ) + + # Add outputs + self.set_block_state(state, block_state) + return components, state + + +class ZImageVaeImageEncoderStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return "Vae Image Encoder step that generate condition_latents based on image to guide the image generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8 * 2}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("image", type_hint=PIL.Image.Image, required=True), + InputParam("height"), + InputParam("width"), + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "image_latents", + type_hint=torch.Tensor, + description="video latent representation with the first frame image condition", + ), + ] + + @staticmethod + def check_inputs(components, block_state): + if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or ( + block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}." + ) + + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + image = block_state.image + + device = components._execution_device + dtype = torch.float32 + vae_dtype = components.vae.dtype + + image_tensor = components.image_processor.preprocess( + image, height=block_state.height, width=block_state.width + ).to(device=device, dtype=dtype) + + block_state.image_latents = encode_vae_image( + image_tensor=image_tensor, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=vae_dtype, + latent_channels=components.num_channels_latents, + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/modular_blocks_z_image.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/modular_blocks_z_image.py new file mode 100644 index 0000000000000000000000000000000000000000..23e20d55fb1e8e3ed23fcb141ce525b3c6e8270b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/modular_blocks_z_image.py @@ -0,0 +1,334 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import OutputParam +from .before_denoise import ( + ZImageAdditionalInputsStep, + ZImagePrepareLatentsStep, + ZImagePrepareLatentswithImageStep, + ZImageSetTimestepsStep, + ZImageSetTimestepsWithStrengthStep, + ZImageTextInputStep, +) +from .decoders import ZImageVaeDecoderStep +from .denoise import ( + ZImageDenoiseStep, +) +from .encoders import ( + ZImageTextEncoderStep, + ZImageVaeImageEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# ==================== +# 1. DENOISE +# ==================== + + +# text2image: inputs(text) -> set_timesteps -> prepare_latents -> denoise +# auto_docstring +class ZImageCoreDenoiseStep(SequentialPipelineBlocks): + """ + denoise block that takes encoded conditions and runs the denoising process. + + Components: + transformer (`ZImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`list`): + Pre-generated text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`list`, *optional*): + Pre-generated negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 9): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + **denoiser_input_fields (`None`, *optional*): + The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + block_classes = [ + ZImageTextInputStep, + ZImagePrepareLatentsStep, + ZImageSetTimestepsStep, + ZImageDenoiseStep, + ] + block_names = ["input", "prepare_latents", "set_timesteps", "denoise"] + + @property + def description(self): + return "denoise block that takes encoded conditions and runs the denoising process." + + @property + def outputs(self): + return [OutputParam.template("latents")] + + +# image2image: inputs(text + image_latents) -> prepare_latents -> set_timesteps -> set_timesteps_with_strength -> prepare_latents_with_image -> denoise +# auto_docstring +class ZImageImage2ImageCoreDenoiseStep(SequentialPipelineBlocks): + """ + denoise block that takes encoded text and image latent conditions and runs the denoising process. + + Components: + transformer (`ZImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`list`): + Pre-generated text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`list`, *optional*): + Pre-generated negative text embeddings. Can be generated from text_encoder step. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + image_latents (`None`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 9): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + strength (`None`, *optional*, defaults to 0.6): + TODO: Add description. + **denoiser_input_fields (`None`, *optional*): + The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + block_classes = [ + ZImageTextInputStep, + ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"]), + ZImagePrepareLatentsStep, + ZImageSetTimestepsStep, + ZImageSetTimestepsWithStrengthStep, + ZImagePrepareLatentswithImageStep, + ZImageDenoiseStep, + ] + block_names = [ + "input", + "additional_inputs", + "prepare_latents", + "set_timesteps", + "set_timesteps_with_strength", + "prepare_latents_with_image", + "denoise", + ] + + @property + def description(self): + return "denoise block that takes encoded text and image latent conditions and runs the denoising process." + + @property + def outputs(self): + return [OutputParam.template("latents")] + + +# auto_docstring +class ZImageAutoDenoiseStep(AutoPipelineBlocks): + """ + Denoise step that iteratively denoise the latents. This is a auto pipeline block that works for text2image and + image2image tasks. - `ZImageCoreDenoiseStep` (text2image) for text2image tasks. - + `ZImageImage2ImageCoreDenoiseStep` (image2image) for image2image tasks. - if `image_latents` is provided, + `ZImageImage2ImageCoreDenoiseStep` will be used. + - if `image_latents` is not provided, `ZImageCoreDenoiseStep` will be used. + + Components: + transformer (`ZImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`list`): + Pre-generated text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`list`, *optional*): + Pre-generated negative text embeddings. Can be generated from text_encoder step. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + image_latents (`None`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_inference_steps (`None`): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + strength (`None`, *optional*, defaults to 0.6): + TODO: Add description. + **denoiser_input_fields (`None`, *optional*): + The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + block_classes = [ + ZImageImage2ImageCoreDenoiseStep, + ZImageCoreDenoiseStep, + ] + block_names = ["image2image", "text2image"] + block_trigger_inputs = ["image_latents", None] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. " + "This is a auto pipeline block that works for text2image and image2image tasks." + " - `ZImageCoreDenoiseStep` (text2image) for text2image tasks." + " - `ZImageImage2ImageCoreDenoiseStep` (image2image) for image2image tasks." + + " - if `image_latents` is provided, `ZImageImage2ImageCoreDenoiseStep` will be used.\n" + + " - if `image_latents` is not provided, `ZImageCoreDenoiseStep` will be used.\n" + ) + + +# auto_docstring +class ZImageAutoVaeImageEncoderStep(AutoPipelineBlocks): + """ + Vae Image Encoder step that encode the image to generate the image latents + + Components: + vae (`AutoencoderKL`) image_processor (`VaeImageProcessor`) + + Inputs: + image (`Image`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + + Outputs: + image_latents (`Tensor`): + video latent representation with the first frame image condition + """ + + block_classes = [ZImageVaeImageEncoderStep] + block_names = ["vae_encoder"] + block_trigger_inputs = ["image"] + + @property + def description(self) -> str: + return "Vae Image Encoder step that encode the image to generate the image latents" + +"This is an auto pipeline block that works for image2image tasks." + +" - `ZImageVaeImageEncoderStep` is used when `image` is provided." + +" - if `image` is not provided, step will be skipped." + + +# auto_docstring +class ZImageAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for text-to-image and image-to-image using ZImage. + + Supported workflows: + - `text2image`: requires `prompt` + - `image2image`: requires `image`, `prompt` + + Components: + text_encoder (`Qwen3Model`) tokenizer (`Qwen2Tokenizer`) guider (`ClassifierFreeGuidance`) vae + (`AutoencoderKL`) image_processor (`VaeImageProcessor`) transformer (`ZImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + prompt (`None`, *optional*): + TODO: Add description. + negative_prompt (`None`, *optional*): + TODO: Add description. + max_sequence_length (`None`, *optional*, defaults to 512): + TODO: Add description. + image (`Image`, *optional*): + TODO: Add description. + height (`None`, *optional*): + TODO: Add description. + width (`None`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + num_images_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + image_latents (`None`, *optional*): + TODO: Add description. + latents (`Tensor | NoneType`): + TODO: Add description. + num_inference_steps (`None`): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + strength (`None`, *optional*, defaults to 0.6): + TODO: Add description. + **denoiser_input_fields (`None`, *optional*): + The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + output_type (`str`, *optional*, defaults to pil): + The type of the output images, can be 'pil', 'np', 'pt' + + Outputs: + images (`list`): + Generated images. + """ + + block_classes = [ + ZImageTextEncoderStep, + ZImageAutoVaeImageEncoderStep, + ZImageAutoDenoiseStep, + ZImageVaeDecoderStep, + ] + block_names = ["text_encoder", "vae_encoder", "denoise", "decode"] + _workflow_map = { + "text2image": {"prompt": True}, + "image2image": {"image": True, "prompt": True}, + } + + @property + def description(self) -> str: + return "Auto Modular pipeline for text-to-image and image-to-image using ZImage." + + @property + def outputs(self): + return [OutputParam.template("images")] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/modular_pipeline.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/modular_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..f1d8e53a3639d481134c418c47a4fc0a9a95400f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/modular_pipelines/z_image/modular_pipeline.py @@ -0,0 +1,72 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...loaders import ZImageLoraLoaderMixin +from ...utils import logging +from ..modular_pipeline import ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class ZImageModularPipeline( + ModularPipeline, + ZImageLoraLoaderMixin, +): + """ + A ModularPipeline for Z-Image. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "ZImageAutoBlocks" + + @property + def default_height(self): + return 1024 + + @property + def default_width(self): + return 1024 + + @property + def vae_scale_factor_spatial(self): + vae_scale_factor_spatial = 16 + if hasattr(self, "image_processor") and self.image_processor is not None: + vae_scale_factor_spatial = self.image_processor.config.vae_scale_factor + return vae_scale_factor_spatial + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 16 + if hasattr(self, "transformer") and self.transformer is not None: + num_channels_latents = self.transformer.config.in_channels + return num_channels_latents + + @property + def requires_unconditional_embeds(self): + requires_unconditional_embeds = False + + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 + + return requires_unconditional_embeds diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef6557774c3b0ff0c0faf76de63f1916328bb8ac Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/__pycache__/auto_pipeline.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/__pycache__/auto_pipeline.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04a3942347a31a20869e29de1c27ef2ef076617f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/__pycache__/auto_pipeline.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/__pycache__/free_init_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/__pycache__/free_init_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..551d764d41f6cfa343a46271651610fa064000e8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/__pycache__/free_init_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/__pycache__/free_noise_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/__pycache__/free_noise_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b91462f0396a2b8eb4a2b43098284a0edfe9271d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/__pycache__/free_noise_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/__pycache__/onnx_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/__pycache__/onnx_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03ad601dadcdbf815129f982052953e0da0315e5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/__pycache__/onnx_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/__pycache__/pipeline_flax_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/__pycache__/pipeline_flax_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7da5178ef1c3ded3c460d15b7a0e37ed73acdb27 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/__pycache__/pipeline_flax_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/__pycache__/pipeline_loading_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/__pycache__/pipeline_loading_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5e5ce73ee1f0a39b3c34da64807892087b69e18 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/__pycache__/pipeline_loading_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/__pycache__/transformers_loading_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/__pycache__/transformers_loading_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5bd48ff84691a07ab6eaabf6c810a8af9efa2ac Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/__pycache__/transformers_loading_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/allegro/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/allegro/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2162b825e0a2d59aa79430fc0b66f664b6feccb5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/allegro/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_allegro"] = ["AllegroPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_allegro import AllegroPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/allegro/pipeline_allegro.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/allegro/pipeline_allegro.py new file mode 100644 index 0000000000000000000000000000000000000000..e54e9ed207396b09b4a9255938f40ace37a12218 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/allegro/pipeline_allegro.py @@ -0,0 +1,986 @@ +# Copyright 2025 The RhymesAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import math +import re +import urllib.parse as ul +from typing import Callable + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models import AllegroTransformer3DModel, AutoencoderKLAllegro +from ...models.embeddings import get_3d_rotary_pos_embed_allegro +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + BACKENDS_MAPPING, + deprecate, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import AllegroPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoencoderKLAllegro, AllegroPipeline + >>> from diffusers.utils import export_to_video + + >>> vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32) + >>> pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", vae=vae, torch_dtype=torch.bfloat16).to("cuda") + >>> pipe.enable_vae_tiling() + + >>> prompt = ( + ... "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, " + ... "the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this " + ... "location might be a popular spot for docking fishing boats." + ... ) + >>> video = pipe(prompt, guidance_scale=7.5, max_sequence_length=512).frames[0] + >>> export_to_video(video, "output.mp4", fps=15) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AllegroPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Allegro. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AllegroAutoEncoderKL3D`]): + Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. PixArt-Alpha uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`AllegroTransformer3DModel`]): + A text conditioned `AllegroTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLAllegro, + transformer: AllegroTransformer3DModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->512, num_images_per_prompt->num_videos_per_prompt + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_videos_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + clean_caption: bool = False, + max_sequence_length: int = 512, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_videos_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" + string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 512): Maximum sequence length to use for the prompt. + """ + + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + + if device is None: + device = self._execution_device + + # See Section 3.1. of the paper. + max_length = max_sequence_length + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because T5 can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt) + prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_videos_per_prompt, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_videos_per_prompt) + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_videos_per_prompt, -1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + num_frames, + height, + width, + callback_on_step_end_tensor_inputs, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if num_frames <= 0: + raise ValueError(f"`num_frames` have to be positive but is {num_frames}.") + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if num_frames % 2 == 0: + num_frames = math.ceil(num_frames / self.vae_scale_factor_temporal) + else: + num_frames = math.ceil((num_frames - 1) / self.vae_scale_factor_temporal) + 1 + + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = 1 / self.vae.config.scaling_factor * latents + frames = self.vae.decode(latents).sample + frames = frames.permute(0, 2, 1, 3, 4) # [batch_size, channels, num_frames, height, width] + return frames + + def _prepare_rotary_positional_embeddings( + self, + batch_size: int, + height: int, + width: int, + num_frames: int, + device: torch.device, + ): + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + start, stop = (0, 0), (grid_height, grid_width) + freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w = get_3d_rotary_pos_embed_allegro( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=(start, stop), + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + interpolation_scale=( + self.transformer.config.interpolation_scale_t, + self.transformer.config.interpolation_scale_h, + self.transformer.config.interpolation_scale_w, + ), + device=device, + ) + + grid_t = grid_t.to(dtype=torch.long) + grid_h = grid_h.to(dtype=torch.long) + grid_w = grid_w.to(dtype=torch.long) + + pos = torch.cartesian_prod(grid_t, grid_h, grid_w) + pos = pos.reshape(-1, 3).transpose(0, 1).reshape(3, 1, -1).contiguous() + grid_t, grid_h, grid_w = pos + + return (freqs_t, freqs_h, freqs_w), (grid_t, grid_h, grid_w) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str = "", + num_inference_steps: int = 100, + timesteps: list[int] = None, + guidance_scale: float = 7.5, + num_frames: int | None = None, + height: int | None = None, + width: int | None = None, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + clean_caption: bool = True, + max_sequence_length: int = 512, + ) -> AllegroPipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to + the text `prompt`, usually at the expense of lower video quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + num_frames: (`int`, *optional*, defaults to 88): + The number controls the generated video frames. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated video. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate video. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + max_sequence_length (`int` defaults to `512`): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.allegro.pipeline_output.AllegroPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.allegro.pipeline_output.AllegroPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated videos. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + num_frames = num_frames or self.transformer.config.sample_frames * self.vae_scale_factor_temporal + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + + self.check_inputs( + prompt, + num_frames, + height, + width, + callback_on_step_end_tensor_inputs, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + self._guidance_scale = guidance_scale + self._current_timestep = None + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + if prompt_embeds.ndim == 3: + prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) + self.scheduler.set_timesteps(num_inference_steps, device=device) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare rotary embeddings + image_rotary_emb = self._prepare_rotary_positional_embeddings( + batch_size, height, width, latents.size(2), device + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + video = self.decode_latents(latents) + video = video[:, :, :num_frames, :height, :width] + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AllegroPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/allegro/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/allegro/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..bf85a4954ce986c3cea771dcc1258596e857951f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/allegro/pipeline_output.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass + +import numpy as np +import PIL +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class AllegroPipelineOutput(BaseOutput): + r""" + Output class for Allegro pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor | np.ndarray | list[list[PIL.Image.Image]] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/amused/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/amused/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3c4d07a426b54fabfcdf35bfb8e4486cd828b3b3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/amused/__init__.py @@ -0,0 +1,62 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + AmusedImg2ImgPipeline, + AmusedInpaintPipeline, + AmusedPipeline, + ) + + _dummy_objects.update( + { + "AmusedPipeline": AmusedPipeline, + "AmusedImg2ImgPipeline": AmusedImg2ImgPipeline, + "AmusedInpaintPipeline": AmusedInpaintPipeline, + } + ) +else: + _import_structure["pipeline_amused"] = ["AmusedPipeline"] + _import_structure["pipeline_amused_img2img"] = ["AmusedImg2ImgPipeline"] + _import_structure["pipeline_amused_inpaint"] = ["AmusedInpaintPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + AmusedPipeline, + ) + else: + from .pipeline_amused import AmusedPipeline + from .pipeline_amused_img2img import AmusedImg2ImgPipeline + from .pipeline_amused_inpaint import AmusedInpaintPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/amused/pipeline_amused.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/amused/pipeline_amused.py new file mode 100644 index 0000000000000000000000000000000000000000..b23adf0d2152b8eafc2cb0d772fb498d5b161661 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/amused/pipeline_amused.py @@ -0,0 +1,342 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable + +import torch +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...models import UVit2DModel, VQModel +from ...schedulers import AmusedScheduler +from ...utils import is_torch_xla_available, replace_example_docstring +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AmusedPipeline + + >>> pipe = AmusedPipeline.from_pretrained("amused/amused-512", variant="fp16", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +class AmusedPipeline(DeprecatedPipelineMixin, DiffusionPipeline): + _last_supported_version = "0.33.1" + image_processor: VaeImageProcessor + vqvae: VQModel + tokenizer: CLIPTokenizer + text_encoder: CLIPTextModelWithProjection + transformer: UVit2DModel + scheduler: AmusedScheduler + + model_cpu_offload_seq = "text_encoder->transformer->vqvae" + + def __init__( + self, + vqvae: VQModel, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + transformer: UVit2DModel, + scheduler: AmusedScheduler, + ): + super().__init__() + + self.register_modules( + vqvae=vqvae, + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: list[str] | str | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 12, + guidance_scale: float = 10.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | None = None, + latents: torch.IntTensor | None = None, + prompt_embeds: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_encoder_hidden_states: torch.Tensor | None = None, + output_type="pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + micro_conditioning_aesthetic_score: int = 6, + micro_conditioning_crop_coord: tuple[int, int] = (0, 0), + temperature: int | tuple[int, int] | list[int] = (2, 0), + ): + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 16): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 10.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.IntTensor`, *optional*): + Pre-generated tokens representing latent vectors in `self.vqvae`, to be used as inputs for image + generation. If not provided, the starting latents will be completely masked. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. A single vector from the + pooled and projected final hidden states. + encoder_hidden_states (`torch.Tensor`, *optional*): + Pre-generated penultimate hidden states from the text encoder providing additional text conditioning. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_encoder_hidden_states (`torch.Tensor`, *optional*): + Analogous to `encoder_hidden_states` for the positive prompt. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6): + The targeted aesthetic score according to the laion aesthetic classifier. See + https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of + https://huggingface.co/papers/2307.01952. + micro_conditioning_crop_coord (`tuple[int]`, *optional*, defaults to (0, 0)): + The targeted height, width crop coordinates. See the micro-conditioning section of + https://huggingface.co/papers/2307.01952. + temperature (`int | tuple[int, int, list[int]]`, *optional*, defaults to (2, 0)): + Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`. + + Examples: + + Returns: + [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a + `tuple` is returned where the first element is a list with the generated images. + """ + if (prompt_embeds is not None and encoder_hidden_states is None) or ( + prompt_embeds is None and encoder_hidden_states is not None + ): + raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither") + + if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or ( + negative_prompt_embeds is None and negative_encoder_hidden_states is not None + ): + raise ValueError( + "pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither" + ) + + if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None): + raise ValueError("pass only one of `prompt` or `prompt_embeds`") + + if isinstance(prompt, str): + prompt = [prompt] + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + batch_size = batch_size * num_images_per_prompt + + if height is None: + height = self.transformer.config.sample_size * self.vae_scale_factor + + if width is None: + width = self.transformer.config.sample_size * self.vae_scale_factor + + if prompt_embeds is None: + input_ids = self.tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + prompt_embeds = outputs.text_embeds + encoder_hidden_states = outputs.hidden_states[-2] + + prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1) + encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + if guidance_scale > 1.0: + if negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = [""] * len(prompt) + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + input_ids = self.tokenizer( + negative_prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + negative_prompt_embeds = outputs.text_embeds + negative_encoder_hidden_states = outputs.hidden_states[-2] + + negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1) + negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds]) + encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states]) + + # Note that the micro conditionings _do_ flip the order of width, height for the original size + # and the crop coordinates. This is how it was done in the original code base + micro_conds = torch.tensor( + [ + width, + height, + micro_conditioning_crop_coord[0], + micro_conditioning_crop_coord[1], + micro_conditioning_aesthetic_score, + ], + device=self._execution_device, + dtype=encoder_hidden_states.dtype, + ) + micro_conds = micro_conds.unsqueeze(0) + micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1) + + shape = (batch_size, height // self.vae_scale_factor, width // self.vae_scale_factor) + + if latents is None: + latents = torch.full( + shape, self.scheduler.config.mask_token_id, dtype=torch.long, device=self._execution_device + ) + + self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device) + + num_warmup_steps = len(self.scheduler.timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, timestep in enumerate(self.scheduler.timesteps): + if guidance_scale > 1.0: + model_input = torch.cat([latents] * 2) + else: + model_input = latents + + model_output = self.transformer( + model_input, + micro_conds=micro_conds, + pooled_text_emb=prompt_embeds, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if guidance_scale > 1.0: + uncond_logits, cond_logits = model_output.chunk(2) + model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + + latents = self.scheduler.step( + model_output=model_output, + timestep=timestep, + sample=latents, + generator=generator, + ).prev_sample + + if i == len(self.scheduler.timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, timestep, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + output = latents + else: + needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast + + if needs_upcasting: + self.vqvae.float() + + output = self.vqvae.decode( + latents, + force_not_quantize=True, + shape=( + batch_size, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + self.vqvae.config.latent_channels, + ), + ).sample.clip(0, 1) + output = self.image_processor.postprocess(output, output_type) + + if needs_upcasting: + self.vqvae.half() + + self.maybe_free_model_hooks() + + if not return_dict: + return (output,) + + return ImagePipelineOutput(output) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/amused/pipeline_amused_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/amused/pipeline_amused_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..79ebd96dedebefc295781cb9eb1c7bd09ab5459b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/amused/pipeline_amused_img2img.py @@ -0,0 +1,363 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable + +import torch +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import UVit2DModel, VQModel +from ...schedulers import AmusedScheduler +from ...utils import is_torch_xla_available, replace_example_docstring +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AmusedImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> pipe = AmusedImg2ImgPipeline.from_pretrained( + ... "amused/amused-512", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "winter mountains" + >>> input_image = ( + ... load_image( + ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg" + ... ) + ... .resize((512, 512)) + ... .convert("RGB") + ... ) + >>> image = pipe(prompt, input_image).images[0] + ``` +""" + + +class AmusedImg2ImgPipeline(DeprecatedPipelineMixin, DiffusionPipeline): + _last_supported_version = "0.33.1" + image_processor: VaeImageProcessor + vqvae: VQModel + tokenizer: CLIPTokenizer + text_encoder: CLIPTextModelWithProjection + transformer: UVit2DModel + scheduler: AmusedScheduler + + model_cpu_offload_seq = "text_encoder->transformer->vqvae" + + # TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before + # the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter + # off the meta device. There should be a way to fix this instead of just not offloading it + _exclude_from_cpu_offload = ["vqvae"] + + def __init__( + self, + vqvae: VQModel, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + transformer: UVit2DModel, + scheduler: AmusedScheduler, + ): + super().__init__() + + self.register_modules( + vqvae=vqvae, + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: list[str] | str | None = None, + image: PipelineImageInput = None, + strength: float = 0.5, + num_inference_steps: int = 12, + guidance_scale: float = 10.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | None = None, + prompt_embeds: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_encoder_hidden_states: torch.Tensor | None = None, + output_type="pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + micro_conditioning_aesthetic_score: int = 6, + micro_conditioning_crop_coord: tuple[int, int] = (0, 0), + temperature: int | tuple[int, int] | list[int] = (2, 0), + ): + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + strength (`float`, *optional*, defaults to 0.5): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 12): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 10.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. A single vector from the + pooled and projected final hidden states. + encoder_hidden_states (`torch.Tensor`, *optional*): + Pre-generated penultimate hidden states from the text encoder providing additional text conditioning. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_encoder_hidden_states (`torch.Tensor`, *optional*): + Analogous to `encoder_hidden_states` for the positive prompt. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6): + The targeted aesthetic score according to the laion aesthetic classifier. See + https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of + https://huggingface.co/papers/2307.01952. + micro_conditioning_crop_coord (`tuple[int]`, *optional*, defaults to (0, 0)): + The targeted height, width crop coordinates. See the micro-conditioning section of + https://huggingface.co/papers/2307.01952. + temperature (`int | tuple[int, int, list[int]]`, *optional*, defaults to (2, 0)): + Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`. + + Examples: + + Returns: + [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a + `tuple` is returned where the first element is a list with the generated images. + """ + + if (prompt_embeds is not None and encoder_hidden_states is None) or ( + prompt_embeds is None and encoder_hidden_states is not None + ): + raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither") + + if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or ( + negative_prompt_embeds is None and negative_encoder_hidden_states is not None + ): + raise ValueError( + "pass either both `negative_prompt_embeds` and `negative_encoder_hidden_states` or neither" + ) + + if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None): + raise ValueError("pass only one of `prompt` or `prompt_embeds`") + + if isinstance(prompt, str): + prompt = [prompt] + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + batch_size = batch_size * num_images_per_prompt + + if prompt_embeds is None: + input_ids = self.tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + prompt_embeds = outputs.text_embeds + encoder_hidden_states = outputs.hidden_states[-2] + + prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1) + encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + if guidance_scale > 1.0: + if negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = [""] * len(prompt) + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + input_ids = self.tokenizer( + negative_prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + negative_prompt_embeds = outputs.text_embeds + negative_encoder_hidden_states = outputs.hidden_states[-2] + + negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1) + negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds]) + encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states]) + + image = self.image_processor.preprocess(image) + + height, width = image.shape[-2:] + + # Note that the micro conditionings _do_ flip the order of width, height for the original size + # and the crop coordinates. This is how it was done in the original code base + micro_conds = torch.tensor( + [ + width, + height, + micro_conditioning_crop_coord[0], + micro_conditioning_crop_coord[1], + micro_conditioning_aesthetic_score, + ], + device=self._execution_device, + dtype=encoder_hidden_states.dtype, + ) + + micro_conds = micro_conds.unsqueeze(0) + micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1) + + self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device) + num_inference_steps = int(len(self.scheduler.timesteps) * strength) + start_timestep_idx = len(self.scheduler.timesteps) - num_inference_steps + + needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast + + if needs_upcasting: + self.vqvae.float() + + latents = self.vqvae.encode(image.to(dtype=self.vqvae.dtype, device=self._execution_device)).latents + latents_bsz, channels, latents_height, latents_width = latents.shape + latents = self.vqvae.quantize(latents)[2][2].reshape(latents_bsz, latents_height, latents_width) + latents = self.scheduler.add_noise( + latents, self.scheduler.timesteps[start_timestep_idx - 1], generator=generator + ) + latents = latents.repeat(num_images_per_prompt, 1, 1) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i in range(start_timestep_idx, len(self.scheduler.timesteps)): + timestep = self.scheduler.timesteps[i] + + if guidance_scale > 1.0: + model_input = torch.cat([latents] * 2) + else: + model_input = latents + + model_output = self.transformer( + model_input, + micro_conds=micro_conds, + pooled_text_emb=prompt_embeds, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if guidance_scale > 1.0: + uncond_logits, cond_logits = model_output.chunk(2) + model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + + latents = self.scheduler.step( + model_output=model_output, + timestep=timestep, + sample=latents, + generator=generator, + ).prev_sample + + if i == len(self.scheduler.timesteps) - 1 or ((i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, timestep, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + output = latents + else: + output = self.vqvae.decode( + latents, + force_not_quantize=True, + shape=( + batch_size, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + self.vqvae.config.latent_channels, + ), + ).sample.clip(0, 1) + output = self.image_processor.postprocess(output, output_type) + + if needs_upcasting: + self.vqvae.half() + + self.maybe_free_model_hooks() + + if not return_dict: + return (output,) + + return ImagePipelineOutput(output) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/amused/pipeline_amused_inpaint.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/amused/pipeline_amused_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..55302401832c5683b0601c5690954472a44074b9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/amused/pipeline_amused_inpaint.py @@ -0,0 +1,394 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Callable + +import torch +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import UVit2DModel, VQModel +from ...schedulers import AmusedScheduler +from ...utils import is_torch_xla_available, replace_example_docstring +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AmusedInpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = AmusedInpaintPipeline.from_pretrained( + ... "amused/amused-512", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "fall mountains" + >>> input_image = ( + ... load_image( + ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg" + ... ) + ... .resize((512, 512)) + ... .convert("RGB") + ... ) + >>> mask = ( + ... load_image( + ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png" + ... ) + ... .resize((512, 512)) + ... .convert("L") + ... ) + >>> pipe(prompt, input_image, mask).images[0].save("out.png") + ``` +""" + + +class AmusedInpaintPipeline(DeprecatedPipelineMixin, DiffusionPipeline): + _last_supported_version = "0.33.1" + image_processor: VaeImageProcessor + vqvae: VQModel + tokenizer: CLIPTokenizer + text_encoder: CLIPTextModelWithProjection + transformer: UVit2DModel + scheduler: AmusedScheduler + + model_cpu_offload_seq = "text_encoder->transformer->vqvae" + + # TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before + # the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter + # off the meta device. There should be a way to fix this instead of just not offloading it + _exclude_from_cpu_offload = ["vqvae"] + + def __init__( + self, + vqvae: VQModel, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + transformer: UVit2DModel, + scheduler: AmusedScheduler, + ): + super().__init__() + + self.register_modules( + vqvae=vqvae, + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + do_resize=True, + ) + self.scheduler.register_to_config(masking_schedule="linear") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: list[str] | str | None = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + strength: float = 1.0, + num_inference_steps: int = 12, + guidance_scale: float = 10.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | None = None, + prompt_embeds: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_encoder_hidden_states: torch.Tensor | None = None, + output_type="pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + micro_conditioning_aesthetic_score: int = 6, + micro_conditioning_crop_coord: tuple[int, int] = (0, 0), + temperature: int | tuple[int, int] | list[int] = (2, 0), + ): + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 16): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 10.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. A single vector from the + pooled and projected final hidden states. + encoder_hidden_states (`torch.Tensor`, *optional*): + Pre-generated penultimate hidden states from the text encoder providing additional text conditioning. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_encoder_hidden_states (`torch.Tensor`, *optional*): + Analogous to `encoder_hidden_states` for the positive prompt. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6): + The targeted aesthetic score according to the laion aesthetic classifier. See + https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of + https://huggingface.co/papers/2307.01952. + micro_conditioning_crop_coord (`tuple[int]`, *optional*, defaults to (0, 0)): + The targeted height, width crop coordinates. See the micro-conditioning section of + https://huggingface.co/papers/2307.01952. + temperature (`int | tuple[int, int, list[int]]`, *optional*, defaults to (2, 0)): + Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`. + + Examples: + + Returns: + [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a + `tuple` is returned where the first element is a list with the generated images. + """ + + if (prompt_embeds is not None and encoder_hidden_states is None) or ( + prompt_embeds is None and encoder_hidden_states is not None + ): + raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither") + + if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or ( + negative_prompt_embeds is None and negative_encoder_hidden_states is not None + ): + raise ValueError( + "pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither" + ) + + if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None): + raise ValueError("pass only one of `prompt` or `prompt_embeds`") + + if isinstance(prompt, str): + prompt = [prompt] + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + batch_size = batch_size * num_images_per_prompt + + if prompt_embeds is None: + input_ids = self.tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + prompt_embeds = outputs.text_embeds + encoder_hidden_states = outputs.hidden_states[-2] + + prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1) + encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + if guidance_scale > 1.0: + if negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = [""] * len(prompt) + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + input_ids = self.tokenizer( + negative_prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + negative_prompt_embeds = outputs.text_embeds + negative_encoder_hidden_states = outputs.hidden_states[-2] + + negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1) + negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds]) + encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states]) + + image = self.image_processor.preprocess(image) + + height, width = image.shape[-2:] + + # Note that the micro conditionings _do_ flip the order of width, height for the original size + # and the crop coordinates. This is how it was done in the original code base + micro_conds = torch.tensor( + [ + width, + height, + micro_conditioning_crop_coord[0], + micro_conditioning_crop_coord[1], + micro_conditioning_aesthetic_score, + ], + device=self._execution_device, + dtype=encoder_hidden_states.dtype, + ) + + micro_conds = micro_conds.unsqueeze(0) + micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1) + + self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device) + num_inference_steps = int(len(self.scheduler.timesteps) * strength) + start_timestep_idx = len(self.scheduler.timesteps) - num_inference_steps + + needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast + + if needs_upcasting: + self.vqvae.float() + + latents = self.vqvae.encode(image.to(dtype=self.vqvae.dtype, device=self._execution_device)).latents + latents_bsz, channels, latents_height, latents_width = latents.shape + latents = self.vqvae.quantize(latents)[2][2].reshape(latents_bsz, latents_height, latents_width) + + mask = self.mask_processor.preprocess( + mask_image, height // self.vae_scale_factor, width // self.vae_scale_factor + ) + mask = mask.reshape(mask.shape[0], latents_height, latents_width).bool().to(latents.device) + latents[mask] = self.scheduler.config.mask_token_id + + starting_mask_ratio = mask.sum() / latents.numel() + + latents = latents.repeat(num_images_per_prompt, 1, 1) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i in range(start_timestep_idx, len(self.scheduler.timesteps)): + timestep = self.scheduler.timesteps[i] + + if guidance_scale > 1.0: + model_input = torch.cat([latents] * 2) + else: + model_input = latents + + model_output = self.transformer( + model_input, + micro_conds=micro_conds, + pooled_text_emb=prompt_embeds, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if guidance_scale > 1.0: + uncond_logits, cond_logits = model_output.chunk(2) + model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + + latents = self.scheduler.step( + model_output=model_output, + timestep=timestep, + sample=latents, + generator=generator, + starting_mask_ratio=starting_mask_ratio, + ).prev_sample + + if i == len(self.scheduler.timesteps) - 1 or ((i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, timestep, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + output = latents + else: + output = self.vqvae.decode( + latents, + force_not_quantize=True, + shape=( + batch_size, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + self.vqvae.config.latent_channels, + ), + ).sample.clip(0, 1) + output = self.image_processor.postprocess(output, output_type) + + if needs_upcasting: + self.vqvae.half() + + self.maybe_free_model_hooks() + + if not return_dict: + return (output,) + + return ImagePipelineOutput(output) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d916abf2d85dfeee217479f14711dc2d33002ae6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__init__.py @@ -0,0 +1,57 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {"pipeline_output": ["AnimateDiffPipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline"] + _import_structure["pipeline_animatediff_controlnet"] = ["AnimateDiffControlNetPipeline"] + _import_structure["pipeline_animatediff_sdxl"] = ["AnimateDiffSDXLPipeline"] + _import_structure["pipeline_animatediff_sparsectrl"] = ["AnimateDiffSparseControlNetPipeline"] + _import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"] + _import_structure["pipeline_animatediff_video2video_controlnet"] = ["AnimateDiffVideoToVideoControlNetPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + + else: + from .pipeline_animatediff import AnimateDiffPipeline + from .pipeline_animatediff_controlnet import AnimateDiffControlNetPipeline + from .pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline + from .pipeline_animatediff_sparsectrl import AnimateDiffSparseControlNetPipeline + from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline + from .pipeline_animatediff_video2video_controlnet import AnimateDiffVideoToVideoControlNetPipeline + from .pipeline_output import AnimateDiffPipelineOutput + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b162a0d501f24dfb6988c37e00e648193489dea Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__pycache__/pipeline_animatediff.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__pycache__/pipeline_animatediff.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7a56171f77484ec818e560704df0b595b67be35 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__pycache__/pipeline_animatediff.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__pycache__/pipeline_animatediff_controlnet.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__pycache__/pipeline_animatediff_controlnet.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c19501497ae9ffaa72dc43ed4dd2c8577c88787e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__pycache__/pipeline_animatediff_controlnet.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__pycache__/pipeline_animatediff_sdxl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__pycache__/pipeline_animatediff_sdxl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36f0680507542337c8b74d0426457cc5fcd50168 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__pycache__/pipeline_animatediff_sdxl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__pycache__/pipeline_animatediff_sparsectrl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__pycache__/pipeline_animatediff_sparsectrl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..032fff989145abd626dccc72044583e7b8c1d5a3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__pycache__/pipeline_animatediff_sparsectrl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__pycache__/pipeline_animatediff_video2video.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__pycache__/pipeline_animatediff_video2video.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e384e1246c39591ae026c8090cab07d65cd7516 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__pycache__/pipeline_animatediff_video2video.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__pycache__/pipeline_animatediff_video2video_controlnet.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__pycache__/pipeline_animatediff_video2video_controlnet.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0649cbd61153e002d7d1acf6e820b59cf8087eaf Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__pycache__/pipeline_animatediff_video2video_controlnet.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__pycache__/pipeline_output.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__pycache__/pipeline_output.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5c620dcbdaa29e8032dfe80fe20bbf27eb0f63e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/__pycache__/pipeline_output.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/pipeline_animatediff.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/pipeline_animatediff.py new file mode 100644 index 0000000000000000000000000000000000000000..4d7477bc8754efd45ae2f1d0d0ce733e60256b32 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -0,0 +1,871 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...models.unets.unet_motion_model import MotionAdapter +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..free_init_utils import FreeInitMixin +from ..free_noise_utils import AnimateDiffFreeNoiseMixin +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import AnimateDiffPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler + >>> from diffusers.utils import export_to_gif + + >>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2") + >>> pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter) + >>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False) + >>> output = pipe(prompt="A corgi walking in the park") + >>> frames = output.frames[0] + >>> export_to_gif(frames, "animation.gif") + ``` +""" + + +class AnimateDiffPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FreeInitMixin, + AnimateDiffFreeNoiseMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-video generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. + motion_adapter ([`MotionAdapter`]): + A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel | UNetMotionModel, + motion_adapter: MotionAdapter, + scheduler: DDIMScheduler + | PNDMScheduler + | LMSDiscreteScheduler + | EulerDiscreteScheduler + | EulerAncestralDiscreteScheduler + | DPMSolverMultistepScheduler, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + if isinstance(unet, UNet2DConditionModel): + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + motion_adapter=motion_adapter, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + def decode_latents(self, latents, decode_chunk_size: int = 16): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + video = [] + for i in range(0, latents.shape[0], decode_chunk_size): + batch_latents = latents[i : i + decode_chunk_size] + batch_latents = self.vae.decode(batch_latents).sample + video.append(batch_latents) + + video = torch.cat(video) + video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and not isinstance(prompt, (str, list, dict)): + raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)=}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://huggingface.co/papers/2310.15169) + if self.free_noise_enabled: + latents = self._prepare_latents_free_noise( + batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + num_frames: int | None = 16, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_videos_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + decode_chunk_size: int = 16, + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + decode_chunk_size (`int`, defaults to `16`): + The number of frames to decode at a time when calling `decode_latents` method. + + Examples: + + Returns: + [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, (str, dict)): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + if self.free_noise_enabled: + prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise( + prompt=prompt, + num_frames=num_frames, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + else: + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 + for free_init_iter in range(num_free_init_iters): + if self.free_init_enabled: + latents, timesteps = self._apply_free_init( + latents, free_init_iter, num_inference_steps, device, latents.dtype, generator + ) + + self._num_timesteps = len(timesteps) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 8. Denoising loop + with self.progress_bar(total=self._num_timesteps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # 9. Post processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents, decode_chunk_size) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # 10. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AnimateDiffPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..eb511129cc6fe17fda1c95612954737aa720994a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py @@ -0,0 +1,1118 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import ( + AutoencoderKL, + ControlNetModel, + ImageProjection, + MultiControlNetModel, + UNet2DConditionModel, + UNetMotionModel, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...models.unets.unet_motion_model import MotionAdapter +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...video_processor import VideoProcessor +from ..free_init_utils import FreeInitMixin +from ..free_noise_utils import AnimateDiffFreeNoiseMixin +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import AnimateDiffPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ( + ... AnimateDiffControlNetPipeline, + ... AutoencoderKL, + ... ControlNetModel, + ... MotionAdapter, + ... LCMScheduler, + ... ) + >>> from diffusers.utils import export_to_gif, load_video + + >>> # Additionally, you will need a preprocess videos before they can be used with the ControlNet + >>> # HF maintains just the right package for it: `pip install controlnet_aux` + >>> from controlnet_aux.processor import ZoeDetector + + >>> # Download controlnets from https://huggingface.co/lllyasviel/ControlNet-v1-1 to use .from_single_file + >>> # Download Diffusers-format controlnets, such as https://huggingface.co/lllyasviel/sd-controlnet-depth, to use .from_pretrained() + >>> controlnet = ControlNetModel.from_single_file("control_v11f1p_sd15_depth.pth", torch_dtype=torch.float16) + + >>> # We use AnimateLCM for this example but one can use the original motion adapters as well (for example, https://huggingface.co/guoyww/animatediff-motion-adapter-v1-5-3) + >>> motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM") + + >>> vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16) + >>> pipe: AnimateDiffControlNetPipeline = AnimateDiffControlNetPipeline.from_pretrained( + ... "SG161222/Realistic_Vision_V5.1_noVAE", + ... motion_adapter=motion_adapter, + ... controlnet=controlnet, + ... vae=vae, + ... ).to(device="cuda", dtype=torch.float16) + >>> pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear") + >>> pipe.load_lora_weights( + ... "wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora" + ... ) + >>> pipe.set_adapters(["lcm-lora"], [0.8]) + + >>> depth_detector = ZoeDetector.from_pretrained("lllyasviel/Annotators").to("cuda") + >>> video = load_video( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif" + ... ) + >>> conditioning_frames = [] + + >>> with pipe.progress_bar(total=len(video)) as progress_bar: + ... for frame in video: + ... conditioning_frames.append(depth_detector(frame)) + ... progress_bar.update() + + >>> prompt = "a panda, playing a guitar, sitting in a pink boat, in the ocean, mountains in background, realistic, high quality" + >>> negative_prompt = "bad quality, worst quality" + + >>> video = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_frames=len(video), + ... num_inference_steps=10, + ... guidance_scale=2.0, + ... conditioning_frames=conditioning_frames, + ... generator=torch.Generator().manual_seed(42), + ... ).frames[0] + + >>> export_to_gif(video, "animatediff_controlnet.gif", fps=8) + ``` +""" + + +class AnimateDiffControlNetPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FreeInitMixin, + AnimateDiffFreeNoiseMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-video generation with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. + motion_adapter ([`MotionAdapter`]): + A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["feature_extractor", "image_encoder"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel | UNetMotionModel, + motion_adapter: MotionAdapter, + controlnet: ControlNetModel | list[ControlNetModel] | tuple[ControlNetModel] | MultiControlNetModel, + scheduler: KarrasDiffusionSchedulers, + feature_extractor: CLIPImageProcessor | None = None, + image_encoder: CLIPVisionModelWithProjection | None = None, + ): + super().__init__() + if isinstance(unet, UNet2DConditionModel): + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + motion_adapter=motion_adapter, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) + self.control_video_processor = VideoProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents + def decode_latents(self, latents, decode_chunk_size: int = 16): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + video = [] + for i in range(0, latents.shape[0], decode_chunk_size): + batch_latents = latents[i : i + decode_chunk_size] + batch_latents = self.vae.decode(batch_latents).sample + video.append(batch_latents) + + video = torch.cat(video) + video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + num_frames, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + video=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and not isinstance(prompt, (str, list, dict)): + raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(video, list): + raise TypeError(f"For single controlnet, `image` must be of type `list` but got {type(video)}") + if len(video) != num_frames: + raise ValueError(f"Excepted image to have length {num_frames} but got {len(video)=}") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(video, list) or not isinstance(video[0], list): + raise TypeError(f"For multiple controlnets: `image` must be type list of lists but got {type(video)=}") + if len(video[0]) != num_frames: + raise ValueError(f"Expected length of image sublist as {num_frames} but got {len(video[0])=}") + if any(len(img) != len(video[0]) for img in video): + raise ValueError("All conditioning frame batches for multicontrolnet must be same size") + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://huggingface.co/papers/2310.15169) + if self.free_noise_enabled: + latents = self._prepare_latents_free_noise( + batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_video( + self, + video, + width, + height, + batch_size, + num_videos_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + video = self.control_video_processor.preprocess_video(video, height=height, width=width).to( + dtype=torch.float32 + ) + video = video.permute(0, 2, 1, 3, 4).flatten(0, 1) + video_batch_size = video.shape[0] + + if video_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_videos_per_prompt + + video = video.repeat_interleave(repeat_by, dim=0) + video = video.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + video = torch.cat([video] * 2) + + return video + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str] = None, + num_frames: int | None = 16, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_videos_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: PipelineImageInput | None = None, + conditioning_frames: list[PipelineImageInput] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + guess_mode: bool = False, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + decode_chunk_size: int = 16, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + conditioning_frames (`list[PipelineImageInput]`, *optional*): + The ControlNet input condition to provide guidance to the `unet` for generation. If multiple + ControlNets are specified, images must be passed as a list such that each element of the list can be + correctly batched for input to a single ControlNet. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + num_frames=num_frames, + negative_prompt=negative_prompt, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + video=conditioning_frames, + controlnet_conditioning_scale=controlnet_conditioning_scale, + control_guidance_start=control_guidance_start, + control_guidance_end=control_guidance_end, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, (str, dict)): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + if self.free_noise_enabled: + prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise( + prompt=prompt, + num_frames=num_frames, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + else: + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + + if isinstance(controlnet, ControlNetModel): + conditioning_frames = self.prepare_video( + video=conditioning_frames, + width=width, + height=height, + batch_size=batch_size * num_videos_per_prompt * num_frames, + num_videos_per_prompt=num_videos_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + cond_prepared_videos = [] + for frame_ in conditioning_frames: + prepared_video = self.prepare_video( + video=frame_, + width=width, + height=height, + batch_size=batch_size * num_videos_per_prompt * num_frames, + num_videos_per_prompt=num_videos_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + cond_prepared_videos.append(prepared_video) + conditioning_frames = cond_prepared_videos + else: + assert False + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 + for free_init_iter in range(num_free_init_iters): + if self.free_init_enabled: + latents, timesteps = self._apply_free_init( + latents, free_init_iter, num_inference_steps, device, latents.dtype, generator + ) + + self._num_timesteps = len(timesteps) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 8. Denoising loop + with self.progress_bar(total=self._num_timesteps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + control_model_input = torch.transpose(control_model_input, 1, 2) + control_model_input = control_model_input.reshape( + (-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4]) + ) + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=conditioning_frames, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ).sample + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # 9. Post processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents, decode_chunk_size) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # 10. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AnimateDiffPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py new file mode 100644 index 0000000000000000000000000000000000000000..68ce7c92896a2e6dd7604e22442f6f36254ae134 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -0,0 +1,1307 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...image_processor import PipelineImageInput +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, MotionAdapter, UNet2DConditionModel, UNetMotionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + FusedAttnProcessor2_0, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..free_init_utils import FreeInitMixin +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import AnimateDiffPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers.models import MotionAdapter + >>> from diffusers import AnimateDiffSDXLPipeline, DDIMScheduler + >>> from diffusers.utils import export_to_gif + + >>> adapter = MotionAdapter.from_pretrained( + ... "a-r-r-o-w/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16 + ... ) + + >>> model_id = "stabilityai/stable-diffusion-xl-base-1.0" + >>> scheduler = DDIMScheduler.from_pretrained( + ... model_id, + ... subfolder="scheduler", + ... clip_sample=False, + ... timestep_spacing="linspace", + ... beta_schedule="linear", + ... steps_offset=1, + ... ) + >>> pipe = AnimateDiffSDXLPipeline.from_pretrained( + ... model_id, + ... motion_adapter=adapter, + ... scheduler=scheduler, + ... torch_dtype=torch.float16, + ... variant="fp16", + ... ).to("cuda") + + >>> # enable memory savings + >>> pipe.enable_vae_slicing() + >>> pipe.enable_vae_tiling() + + >>> output = pipe( + ... prompt="a panda surfing in the ocean, realistic, high quality", + ... negative_prompt="low quality, worst quality", + ... num_inference_steps=20, + ... guidance_scale=8, + ... width=1024, + ... height=1024, + ... num_frames=16, + ... ) + + >>> frames = output.frames[0] + >>> export_to_gif(frames, "animation.gif") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AnimateDiffSDXLPipeline( + DiffusionPipeline, + StableDiffusionMixin, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + FreeInitMixin, +): + r""" + Pipeline for text-to-video generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel | UNetMotionModel, + motion_adapter: MotionAdapter, + scheduler: DDIMScheduler + | PNDMScheduler + | LMSDiscreteScheduler + | EulerDiscreteScheduler + | EulerAncestralDiscreteScheduler + | DPMSolverMultistepScheduler, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = True, + ): + super().__init__() + + if isinstance(unet, UNet2DConditionModel): + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + motion_adapter=motion_adapter, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt with num_images_per_prompt->num_videos_per_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_videos_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_videos_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_videos_per_prompt).view( + bs_embed * num_videos_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_videos_per_prompt).view( + bs_embed * num_videos_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + image = self.vae.decode(latents).sample + video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + FusedAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + num_frames: int = 16, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + denoising_end: float | None = None, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_videos_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + guidance_rescale: float = 0.0, + original_size: tuple[int, int] | None = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] | None = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + num_frames: + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated video. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated video. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower video quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the video generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. If not provided, embeddings are computed from the + `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.AnimateDiffPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + num_videos_per_prompt = 1 + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_videos_per_prompt, 1) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + + # 7.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 8. Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_videos_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 + for free_init_iter in range(num_free_init_iters): + if self.free_init_enabled: + latents, timesteps = self._apply_free_init( + latents, free_init_iter, num_inference_steps, device, latents.dtype, generator + ) + + self._num_timesteps = len(timesteps) + + # 9. Denoising loop + with self.progress_bar(total=self._num_timesteps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None or ip_adapter_image_embeds: + added_cond_kwargs["image_embeds"] = image_embeds + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg( + noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # 10. Post processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + + # 11. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AnimateDiffPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py new file mode 100644 index 0000000000000000000000000000000000000000..8b3eb8fc3c0313df4ec3127af8adcc16cc5feb4a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py @@ -0,0 +1,1023 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel +from ...models.controlnets.controlnet_sparsectrl import SparseControlNetModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...models.unets.unet_motion_model import MotionAdapter +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...video_processor import VideoProcessor +from ..free_init_utils import FreeInitMixin +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import AnimateDiffPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import AnimateDiffSparseControlNetPipeline + >>> from diffusers.models import AutoencoderKL, MotionAdapter, SparseControlNetModel + >>> from diffusers.schedulers import DPMSolverMultistepScheduler + >>> from diffusers.utils import export_to_gif, load_image + + >>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE" + >>> motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-3" + >>> controlnet_id = "guoyww/animatediff-sparsectrl-scribble" + >>> lora_adapter_id = "guoyww/animatediff-motion-lora-v1-5-3" + >>> vae_id = "stabilityai/sd-vae-ft-mse" + >>> device = "cuda" + + >>> motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id, torch_dtype=torch.float16).to(device) + >>> controlnet = SparseControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16).to(device) + >>> vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16).to(device) + >>> scheduler = DPMSolverMultistepScheduler.from_pretrained( + ... model_id, + ... subfolder="scheduler", + ... beta_schedule="linear", + ... algorithm_type="dpmsolver++", + ... use_karras_sigmas=True, + ... ) + >>> pipe = AnimateDiffSparseControlNetPipeline.from_pretrained( + ... model_id, + ... motion_adapter=motion_adapter, + ... controlnet=controlnet, + ... vae=vae, + ... scheduler=scheduler, + ... torch_dtype=torch.float16, + ... ).to(device) + >>> pipe.load_lora_weights(lora_adapter_id, adapter_name="motion_lora") + >>> pipe.fuse_lora(lora_scale=1.0) + + >>> prompt = "an aerial view of a cyberpunk city, night time, neon lights, masterpiece, high quality" + >>> negative_prompt = "low quality, worst quality, letterboxed" + + >>> image_files = [ + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-1.png", + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-2.png", + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-3.png", + ... ] + >>> condition_frame_indices = [0, 8, 15] + >>> conditioning_frames = [load_image(img_file) for img_file in image_files] + + >>> video = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_inference_steps=25, + ... conditioning_frames=conditioning_frames, + ... controlnet_conditioning_scale=1.0, + ... controlnet_frame_indices=condition_frame_indices, + ... generator=torch.Generator().manual_seed(1337), + ... ).frames[0] + >>> export_to_gif(video, "output.gif") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class AnimateDiffSparseControlNetPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FreeInitMixin, + FromSingleFileMixin, +): + r""" + Pipeline for controlled text-to-video generation using the method described in [SparseCtrl: Adding Sparse Controls + to Text-to-Video Diffusion Models](https://huggingface.co/papers/2311.16933). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. + motion_adapter ([`MotionAdapter`]): + A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel | UNetMotionModel, + motion_adapter: MotionAdapter, + controlnet: SparseControlNetModel, + scheduler: KarrasDiffusionSchedulers, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + if isinstance(unet, UNet2DConditionModel): + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + motion_adapter=motion_adapter, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + image = self.vae.decode(latents).sample + video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + image=None, + controlnet_conditioning_scale: float = 1.0, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + + # check `image` + if ( + isinstance(self.controlnet, SparseControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, SparseControlNetModel) + ): + if isinstance(image, list): + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + self.check_image(image, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, SparseControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, SparseControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + else: + assert False + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_image(self, image, width, height, device, dtype): + image = self.control_image_processor.preprocess(image, height=height, width=width) + controlnet_images = image.unsqueeze(0).to(device, dtype) + batch_size, num_frames, channels, height, width = controlnet_images.shape + + # TODO: remove below line + assert controlnet_images.min() >= 0 and controlnet_images.max() <= 1 + + if self.controlnet.use_simplified_condition_embedding: + controlnet_images = controlnet_images.reshape(batch_size * num_frames, channels, height, width) + controlnet_images = 2 * controlnet_images - 1 + conditioning_frames = retrieve_latents(self.vae.encode(controlnet_images)) * self.vae.config.scaling_factor + conditioning_frames = conditioning_frames.reshape( + batch_size, num_frames, 4, height // self.vae_scale_factor, width // self.vae_scale_factor + ) + else: + conditioning_frames = controlnet_images + + conditioning_frames = conditioning_frames.permute(0, 2, 1, 3, 4) # [b, c, f, h, w] + return conditioning_frames + + def prepare_sparse_control_conditioning( + self, + conditioning_frames: torch.Tensor, + num_frames: int, + controlnet_frame_indices: int, + device: torch.device, + dtype: torch.dtype, + ) -> tuple[torch.Tensor, torch.Tensor]: + assert conditioning_frames.shape[2] >= len(controlnet_frame_indices) + + batch_size, channels, _, height, width = conditioning_frames.shape + controlnet_cond = torch.zeros((batch_size, channels, num_frames, height, width), dtype=dtype, device=device) + controlnet_cond_mask = torch.zeros((batch_size, 1, num_frames, height, width), dtype=dtype, device=device) + controlnet_cond[:, :, controlnet_frame_indices] = conditioning_frames[:, :, : len(controlnet_frame_indices)] + controlnet_cond_mask[:, :, controlnet_frame_indices] = 1 + + return controlnet_cond, controlnet_cond_mask + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_frames: int = 16, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + conditioning_frames: list[PipelineImageInput] | None = None, + output_type: str = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + controlnet_frame_indices: list[int] = [0], + guess_mode: bool = False, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + conditioning_frames (`list[PipelineImageInput]`, *optional*): + The SparseControlNet input to provide guidance to the `unet` for generation. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + controlnet_frame_indices (`list[int]`): + The indices where the conditioning frames must be applied for generation. Multiple frames can be + provided to guide the model to generate similar structure outputs, where the `unet` can + "fill-in-the-gaps" for interpolation videos, or a single frame could be provided for general expected + structure. Must have the same length as `conditioning_frames`. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ip_adapter_image=ip_adapter_image, + ip_adapter_image_embeds=ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + image=conditioning_frames, + controlnet_conditioning_scale=controlnet_conditioning_scale, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, SparseControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + + # 4. Prepare IP-Adapter embeddings + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + + # 5. Prepare controlnet conditioning + conditioning_frames = self.prepare_image(conditioning_frames, width, height, device, controlnet.dtype) + controlnet_cond, controlnet_cond_mask = self.prepare_sparse_control_conditioning( + conditioning_frames, num_frames, controlnet_frame_indices, device, controlnet.dtype + ) + + # 6. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 7. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 + for free_init_iter in range(num_free_init_iters): + if self.free_init_enabled: + latents, timesteps = self._apply_free_init( + latents, free_init_iter, num_inference_steps, device, latents.dtype, generator + ) + + self._num_timesteps = len(timesteps) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 10. Denoising loop + with self.progress_bar(total=self._num_timesteps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if guess_mode and self.do_classifier_free_guidance: + # Infer SparseControlNetModel only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=controlnet_cond, + conditioning_mask=controlnet_cond_mask, + conditioning_scale=controlnet_conditioning_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ).sample + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # 11. Post processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # 12. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AnimateDiffPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py new file mode 100644 index 0000000000000000000000000000000000000000..4e7cd21fc25d413c4df57819280d29a3345aeb97 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -0,0 +1,1067 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...models.unets.unet_motion_model import MotionAdapter +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..free_init_utils import FreeInitMixin +from ..free_noise_utils import AnimateDiffFreeNoiseMixin +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import AnimateDiffPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import imageio + >>> import requests + >>> import torch + >>> from diffusers import AnimateDiffVideoToVideoPipeline, DDIMScheduler, MotionAdapter + >>> from diffusers.utils import export_to_gif + >>> from io import BytesIO + >>> from PIL import Image + + >>> adapter = MotionAdapter.from_pretrained( + ... "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16 + ... ) + >>> pipe = AnimateDiffVideoToVideoPipeline.from_pretrained( + ... "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter + ... ).to("cuda") + >>> pipe.scheduler = DDIMScheduler( + ... beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace" + ... ) + + + >>> def load_video(file_path: str): + ... images = [] + + ... if file_path.startswith(("http://", "https://")): + ... # If the file_path is a URL + ... response = requests.get(file_path) + ... response.raise_for_status() + ... content = BytesIO(response.content) + ... vid = imageio.get_reader(content) + ... else: + ... # Assuming it's a local file path + ... vid = imageio.get_reader(file_path) + + ... for frame in vid: + ... pil_image = Image.fromarray(frame) + ... images.append(pil_image) + + ... return images + + + >>> video = load_video( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif" + ... ) + >>> output = pipe( + ... video=video, prompt="panda playing a guitar, on a boat, in the ocean, high quality", strength=0.5 + ... ) + >>> frames = output.frames[0] + >>> export_to_gif(frames, "animation.gif") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AnimateDiffVideoToVideoPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FreeInitMixin, + AnimateDiffFreeNoiseMixin, + FromSingleFileMixin, +): + r""" + Pipeline for video-to-video generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. + motion_adapter ([`MotionAdapter`]): + A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel | UNetMotionModel, + motion_adapter: MotionAdapter, + scheduler: DDIMScheduler + | PNDMScheduler + | LMSDiscreteScheduler + | EulerDiscreteScheduler + | EulerAncestralDiscreteScheduler + | DPMSolverMultistepScheduler, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + if isinstance(unet, UNet2DConditionModel): + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + motion_adapter=motion_adapter, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, (str, dict)): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + def encode_video(self, video, generator, decode_chunk_size: int = 16) -> torch.Tensor: + latents = [] + for i in range(0, len(video), decode_chunk_size): + batch_video = video[i : i + decode_chunk_size] + batch_video = retrieve_latents(self.vae.encode(batch_video), generator=generator) + latents.append(batch_video) + return torch.cat(latents) + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents + def decode_latents(self, latents, decode_chunk_size: int = 16): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + video = [] + for i in range(0, latents.shape[0], decode_chunk_size): + batch_latents = latents[i : i + decode_chunk_size] + batch_latents = self.vae.decode(batch_latents).sample + video.append(batch_latents) + + video = torch.cat(video) + video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + strength, + height, + width, + video=None, + latents=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and not isinstance(prompt, (str, list, dict)): + raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if video is not None and latents is not None: + raise ValueError("Only one of `video` or `latents` should be provided") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def get_timesteps(self, num_inference_steps, timesteps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def prepare_latents( + self, + video: torch.Tensor | None = None, + height: int = 64, + width: int = 64, + num_channels_latents: int = 4, + batch_size: int = 1, + timestep: int | None = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + decode_chunk_size: int = 16, + add_noise: bool = False, + ) -> torch.Tensor: + num_frames = video.shape[1] if latents is None else latents.shape[2] + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + video = video.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + init_latents = [ + self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0) + for i in range(batch_size) + ] + else: + init_latents = [self.encode_video(vid, generator, decode_chunk_size).unsqueeze(0) for vid in video] + + init_latents = torch.cat(init_latents, dim=0) + + # restore vae to original dtype + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + error_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Please make sure to update your script to pass as many initial images as text prompts" + ) + raise ValueError(error_message) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.add_noise(init_latents, noise, timestep).permute(0, 2, 1, 3, 4) + else: + if shape != latents.shape: + # [B, C, F, H, W] + raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}") + + latents = latents.to(device, dtype=dtype) + + if add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.add_noise(latents, noise, timestep) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + def __call__( + self, + video: list[list[PipelineImageInput]] = None, + prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + enforce_inference_steps: bool = False, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + guidance_scale: float = 7.5, + strength: float = 0.8, + negative_prompt: str | list[str] | None = None, + num_videos_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + decode_chunk_size: int = 16, + ): + r""" + The call function to the pipeline for generation. + + Args: + video (`list[PipelineImageInput]`): + The input video to condition the generation on. Must be a list of images/frames of the video. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + strength (`float`, *optional*, defaults to 0.8): + Higher strength leads to more differences between original video and generated video. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`AnimateDiffPipelineOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + decode_chunk_size (`int`, defaults to `16`): + The number of frames to decode at a time when calling `decode_latents` method. + + Examples: + + Returns: + [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + strength=strength, + height=height, + width=width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + video=video, + latents=latents, + ip_adapter_image=ip_adapter_image, + ip_adapter_image_embeds=ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, (str, dict)): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + dtype = self.dtype + + # 3. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + if not enforce_inference_steps: + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + else: + denoising_inference_steps = int(num_inference_steps / strength) + timesteps, denoising_inference_steps = retrieve_timesteps( + self.scheduler, denoising_inference_steps, timestep_device, timesteps, sigmas + ) + timesteps = timesteps[-num_inference_steps:] + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + + # 4. Prepare latent variables + if latents is None: + video = self.video_processor.preprocess_video(video, height=height, width=width) + # Move the number of frames before the number of channels. + video = video.permute(0, 2, 1, 3, 4) + video = video.to(device=device, dtype=dtype) + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + video=video, + height=height, + width=width, + num_channels_latents=num_channels_latents, + batch_size=batch_size * num_videos_per_prompt, + timestep=latent_timestep, + dtype=dtype, + device=device, + generator=generator, + latents=latents, + decode_chunk_size=decode_chunk_size, + add_noise=enforce_inference_steps, + ) + + # 5. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + num_frames = latents.shape[2] + if self.free_noise_enabled: + prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise( + prompt=prompt, + num_frames=num_frames, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + else: + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + + # 6. Prepare IP-Adapter embeddings + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 + for free_init_iter in range(num_free_init_iters): + if self.free_init_enabled: + latents, timesteps = self._apply_free_init( + latents, free_init_iter, num_inference_steps, device, latents.dtype, generator + ) + num_inference_steps = len(timesteps) + # make sure to readjust timesteps based on strength + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) + + self._num_timesteps = len(timesteps) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 9. Denoising loop + with self.progress_bar(total=self._num_timesteps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # 10. Post-processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents, decode_chunk_size) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # 11. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AnimateDiffPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..56ed5e23c1dbefc26f3ebca650ca7c55b3514275 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py @@ -0,0 +1,1355 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import ( + AutoencoderKL, + ControlNetModel, + ImageProjection, + MultiControlNetModel, + UNet2DConditionModel, + UNetMotionModel, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...models.unets.unet_motion_model import MotionAdapter +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...video_processor import VideoProcessor +from ..free_init_utils import FreeInitMixin +from ..free_noise_utils import AnimateDiffFreeNoiseMixin +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import AnimateDiffPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from PIL import Image + >>> from tqdm.auto import tqdm + + >>> from diffusers import AnimateDiffVideoToVideoControlNetPipeline + >>> from diffusers.utils import export_to_gif, load_video + >>> from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter, LCMScheduler + + >>> controlnet = ControlNetModel.from_pretrained( + ... "lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16 + ... ) + >>> motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM") + >>> vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16) + + >>> pipe = AnimateDiffVideoToVideoControlNetPipeline.from_pretrained( + ... "SG161222/Realistic_Vision_V5.1_noVAE", + ... motion_adapter=motion_adapter, + ... controlnet=controlnet, + ... vae=vae, + ... ).to(device="cuda", dtype=torch.float16) + + >>> pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear") + >>> pipe.load_lora_weights( + ... "wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora" + ... ) + >>> pipe.set_adapters(["lcm-lora"], [0.8]) + + >>> video = load_video( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/dance.gif" + ... ) + >>> video = [frame.convert("RGB") for frame in video] + + >>> from controlnet_aux.processor import OpenposeDetector + + >>> open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators").to("cuda") + >>> for frame in tqdm(video): + ... conditioning_frames.append(open_pose(frame)) + + >>> prompt = "astronaut in space, dancing" + >>> negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly" + + >>> strength = 0.8 + >>> with torch.inference_mode(): + ... video = pipe( + ... video=video, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_inference_steps=10, + ... guidance_scale=2.0, + ... controlnet_conditioning_scale=0.75, + ... conditioning_frames=conditioning_frames, + ... strength=strength, + ... generator=torch.Generator().manual_seed(42), + ... ).frames[0] + + >>> video = [frame.resize(conditioning_frames[0].size) for frame in video] + >>> export_to_gif(video, f"animatediff_vid2vid_controlnet.gif", fps=8) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AnimateDiffVideoToVideoControlNetPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FreeInitMixin, + AnimateDiffFreeNoiseMixin, + FromSingleFileMixin, +): + r""" + Pipeline for video-to-video generation with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. + motion_adapter ([`MotionAdapter`]): + A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. + controlnet ([`ControlNetModel`] or `list[ControlNetModel]` or `tuple[ControlNetModel]` or `MultiControlNetModel`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel | UNetMotionModel, + motion_adapter: MotionAdapter, + controlnet: ControlNetModel | list[ControlNetModel] | tuple[ControlNetModel] | MultiControlNetModel, + scheduler: DDIMScheduler + | PNDMScheduler + | LMSDiscreteScheduler + | EulerDiscreteScheduler + | EulerAncestralDiscreteScheduler + | DPMSolverMultistepScheduler, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + if isinstance(unet, UNet2DConditionModel): + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + motion_adapter=motion_adapter, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) + self.control_video_processor = VideoProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, (str, dict)): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.encode_video + def encode_video(self, video, generator, decode_chunk_size: int = 16) -> torch.Tensor: + latents = [] + for i in range(0, len(video), decode_chunk_size): + batch_video = video[i : i + decode_chunk_size] + batch_video = retrieve_latents(self.vae.encode(batch_video), generator=generator) + latents.append(batch_video) + return torch.cat(latents) + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents + def decode_latents(self, latents, decode_chunk_size: int = 16): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + video = [] + for i in range(0, latents.shape[0], decode_chunk_size): + batch_latents = latents[i : i + decode_chunk_size] + batch_latents = self.vae.decode(batch_latents).sample + video.append(batch_latents) + + video = torch.cat(video) + video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + strength, + height, + width, + video=None, + conditioning_frames=None, + latents=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and not isinstance(prompt, (str, list, dict)): + raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if video is not None and latents is not None: + raise ValueError("Only one of `video` or `latents` should be provided") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + + num_frames = len(video) if latents is None else latents.shape[2] + + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(conditioning_frames, list): + raise TypeError( + f"For single controlnet, `image` must be of type `list` but got {type(conditioning_frames)}" + ) + if len(conditioning_frames) != num_frames: + raise ValueError(f"Excepted image to have length {num_frames} but got {len(conditioning_frames)=}") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(conditioning_frames, list) or not isinstance(conditioning_frames[0], list): + raise TypeError( + f"For multiple controlnets: `image` must be type list of lists but got {type(conditioning_frames)=}" + ) + if len(conditioning_frames[0]) != num_frames: + raise ValueError( + f"Expected length of image sublist as {num_frames} but got {len(conditioning_frames)=}" + ) + if any(len(img) != len(conditioning_frames[0]) for img in conditioning_frames): + raise ValueError("All conditioning frame batches for multicontrolnet must be same size") + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, timesteps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.prepare_latents + def prepare_latents( + self, + video: torch.Tensor | None = None, + height: int = 64, + width: int = 64, + num_channels_latents: int = 4, + batch_size: int = 1, + timestep: int | None = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + decode_chunk_size: int = 16, + add_noise: bool = False, + ) -> torch.Tensor: + num_frames = video.shape[1] if latents is None else latents.shape[2] + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + video = video.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + init_latents = [ + self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0) + for i in range(batch_size) + ] + else: + init_latents = [self.encode_video(vid, generator, decode_chunk_size).unsqueeze(0) for vid in video] + + init_latents = torch.cat(init_latents, dim=0) + + # restore vae to original dtype + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + error_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Please make sure to update your script to pass as many initial images as text prompts" + ) + raise ValueError(error_message) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.add_noise(init_latents, noise, timestep).permute(0, 2, 1, 3, 4) + else: + if shape != latents.shape: + # [B, C, F, H, W] + raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}") + + latents = latents.to(device, dtype=dtype) + + if add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.add_noise(latents, noise, timestep) + + return latents + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_controlnet.AnimateDiffControlNetPipeline.prepare_video + def prepare_conditioning_frames( + self, + video, + width, + height, + batch_size, + num_videos_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + video = self.control_video_processor.preprocess_video(video, height=height, width=width).to( + dtype=torch.float32 + ) + video = video.permute(0, 2, 1, 3, 4).flatten(0, 1) + video_batch_size = video.shape[0] + + if video_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_videos_per_prompt + + video = video.repeat_interleave(repeat_by, dim=0) + video = video.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + video = torch.cat([video] * 2) + + return video + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + def __call__( + self, + video: list[list[PipelineImageInput]] = None, + prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + enforce_inference_steps: bool = False, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + guidance_scale: float = 7.5, + strength: float = 0.8, + negative_prompt: str | list[str] | None = None, + num_videos_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + conditioning_frames: list[PipelineImageInput] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + guess_mode: bool = False, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + decode_chunk_size: int = 16, + ): + r""" + The call function to the pipeline for generation. + + Args: + video (`list[PipelineImageInput]`): + The input video to condition the generation on. Must be a list of images/frames of the video. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + strength (`float`, *optional*, defaults to 0.8): + Higher strength leads to more differences between original video and generated video. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + conditioning_frames (`list[PipelineImageInput]`, *optional*): + The ControlNet input condition to provide guidance to the `unet` for generation. If multiple + ControlNets are specified, images must be passed as a list such that each element of the list can be + correctly batched for input to a single ControlNet. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`AnimateDiffPipelineOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + decode_chunk_size (`int`, defaults to `16`): + The number of frames to decode at a time when calling `decode_latents` method. + + Examples: + + Returns: + [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + strength=strength, + height=height, + width=width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + video=video, + conditioning_frames=conditioning_frames, + latents=latents, + ip_adapter_image=ip_adapter_image, + ip_adapter_image_embeds=ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + controlnet_conditioning_scale=controlnet_conditioning_scale, + control_guidance_start=control_guidance_start, + control_guidance_end=control_guidance_end, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, (str, dict)): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + dtype = self.dtype + + # 3. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + if not enforce_inference_steps: + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + else: + denoising_inference_steps = int(num_inference_steps / strength) + timesteps, denoising_inference_steps = retrieve_timesteps( + self.scheduler, denoising_inference_steps, timestep_device, timesteps, sigmas + ) + timesteps = timesteps[-num_inference_steps:] + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + + # 4. Prepare latent variables + if latents is None: + video = self.video_processor.preprocess_video(video, height=height, width=width) + # Move the number of frames before the number of channels. + video = video.permute(0, 2, 1, 3, 4) + video = video.to(device=device, dtype=dtype) + + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + video=video, + height=height, + width=width, + num_channels_latents=num_channels_latents, + batch_size=batch_size * num_videos_per_prompt, + timestep=latent_timestep, + dtype=dtype, + device=device, + generator=generator, + latents=latents, + decode_chunk_size=decode_chunk_size, + add_noise=enforce_inference_steps, + ) + + # 5. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + num_frames = latents.shape[2] + if self.free_noise_enabled: + prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise( + prompt=prompt, + num_frames=num_frames, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + else: + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + + # 6. Prepare IP-Adapter embeddings + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + + # 7. Prepare ControlNet conditions + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + if isinstance(controlnet, ControlNetModel): + conditioning_frames = self.prepare_conditioning_frames( + video=conditioning_frames, + width=width, + height=height, + batch_size=batch_size * num_videos_per_prompt * num_frames, + num_videos_per_prompt=num_videos_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + cond_prepared_videos = [] + for frame_ in conditioning_frames: + prepared_video = self.prepare_conditioning_frames( + video=frame_, + width=width, + height=height, + batch_size=batch_size * num_videos_per_prompt * num_frames, + num_videos_per_prompt=num_videos_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + cond_prepared_videos.append(prepared_video) + conditioning_frames = cond_prepared_videos + else: + assert False + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 + for free_init_iter in range(num_free_init_iters): + if self.free_init_enabled: + latents, timesteps = self._apply_free_init( + latents, free_init_iter, num_inference_steps, device, latents.dtype, generator + ) + num_inference_steps = len(timesteps) + # make sure to readjust timesteps based on strength + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) + + self._num_timesteps = len(timesteps) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 10. Denoising loop + with self.progress_bar(total=self._num_timesteps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + control_model_input = torch.transpose(control_model_input, 1, 2) + control_model_input = control_model_input.reshape( + (-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4]) + ) + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=conditioning_frames, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ).sample + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # 11. Post-processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents, decode_chunk_size) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # 12. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AnimateDiffPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..436a20f455bfb9ee6851a537737c9529fd5cff9a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/animatediff/pipeline_output.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image +import torch + +from ...utils import BaseOutput + + +@dataclass +class AnimateDiffPipelineOutput(BaseOutput): + r""" + Output class for AnimateDiff pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised + PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)` + """ + + frames: torch.Tensor | np.ndarray | list[list[PIL.Image.Image]] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/audioldm/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/audioldm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a002b4aa72e0a180c7042c406667d37122d6e4cc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/audioldm/__init__.py @@ -0,0 +1,51 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + AudioLDMPipeline, + ) + + _dummy_objects.update({"AudioLDMPipeline": AudioLDMPipeline}) +else: + _import_structure["pipeline_audioldm"] = ["AudioLDMPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + AudioLDMPipeline, + ) + + else: + from .pipeline_audioldm import AudioLDMPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/audioldm/pipeline_audioldm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/audioldm/pipeline_audioldm.py new file mode 100644 index 0000000000000000000000000000000000000000..357c3582b21cf838e3b8812853b3c4ba1814c0c3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/audioldm/pipeline_audioldm.py @@ -0,0 +1,558 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +import torch.nn.functional as F +from transformers import ClapTextModelWithProjection, RobertaTokenizer, RobertaTokenizerFast, SpeechT5HifiGan + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import AudioPipelineOutput, DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import AudioLDMPipeline + >>> import torch + >>> import scipy + + >>> repo_id = "cvssp/audioldm-s-full-v2" + >>> pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "Techno music with a strong, upbeat tempo and high melodic riffs" + >>> audio = pipe(prompt, num_inference_steps=10, audio_length_in_s=5.0).audios[0] + + >>> # save the audio sample as a .wav file + >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio) + ``` +""" + + +class AudioLDMPipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin): + r""" + Pipeline for text-to-audio generation using AudioLDM. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.ClapTextModelWithProjection`]): + Frozen text-encoder (`ClapTextModelWithProjection`, specifically the + [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. + tokenizer ([`PreTrainedTokenizer`]): + A [`~transformers.RobertaTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded audio latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + vocoder ([`~transformers.SpeechT5HifiGan`]): + Vocoder of class `SpeechT5HifiGan`. + """ + + _last_supported_version = "0.33.1" + model_cpu_offload_seq = "text_encoder->unet->vae" + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: ClapTextModelWithProjection, + tokenizer: RobertaTokenizer | RobertaTokenizerFast, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + vocoder: SpeechT5HifiGan, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + vocoder=vocoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + + def _encode_prompt( + self, + prompt, + device, + num_waveforms_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device (`torch.device`): + torch device + num_waveforms_per_prompt (`int`): + number of waveforms that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the audio generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLAP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask.to(device), + ) + prompt_embeds = prompt_embeds.text_embeds + # additional L_2 normalization over each hidden-state + prompt_embeds = F.normalize(prompt_embeds, dim=-1) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + ( + bs_embed, + seq_len, + ) = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt) + prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_input_ids = uncond_input.input_ids.to(device) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input_ids, + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds.text_embeds + # additional L_2 normalization over each hidden-state + negative_prompt_embeds = F.normalize(negative_prompt_embeds, dim=-1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + mel_spectrogram = self.vae.decode(latents).sample + return mel_spectrogram + + def mel_spectrogram_to_waveform(self, mel_spectrogram): + if mel_spectrogram.dim() == 4: + mel_spectrogram = mel_spectrogram.squeeze(1) + + waveform = self.vocoder(mel_spectrogram) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + waveform = waveform.cpu().float() + return waveform + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + audio_length_in_s, + vocoder_upsample_factor, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor + if audio_length_in_s < min_audio_length_in_s: + raise ValueError( + f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but " + f"is {audio_length_in_s}." + ) + + if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0: + raise ValueError( + f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the " + f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of " + f"{self.vae_scale_factor}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim + def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(self.vocoder.config.model_in_dim) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + audio_length_in_s: float | None = None, + num_inference_steps: int = 10, + guidance_scale: float = 2.5, + negative_prompt: str | list[str] | None = None, + num_waveforms_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int | None = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + output_type: str | None = "np", + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`. + audio_length_in_s (`int`, *optional*, defaults to 5.12): + The length of the generated audio sample in seconds. + num_inference_steps (`int`, *optional*, defaults to 10): + The number of denoising steps. More denoising steps usually lead to a higher quality audio at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 2.5): + A higher guidance scale value encourages the model to generate audio that is closely linked to the text + `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in audio generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_waveforms_per_prompt (`int`, *optional*, defaults to 1): + The number of waveforms to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `"np"` to return a NumPy `np.ndarray` or + `"pt"` to return a PyTorch `torch.Tensor` object. + + Examples: + + Returns: + [`~pipelines.AudioPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.AudioPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated audio. + """ + # 0. Convert audio input length from seconds to spectrogram height + vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate + + if audio_length_in_s is None: + audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor + + height = int(audio_length_in_s / vocoder_upsample_factor) + + original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate) + if height % self.vae_scale_factor != 0: + height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor + logger.info( + f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} " + f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the " + f"denoising process." + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + audio_length_in_s, + vocoder_upsample_factor, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_waveforms_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_waveforms_per_prompt, + num_channels_latents, + height, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=None, + class_labels=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # 8. Post-processing + mel_spectrogram = self.decode_latents(latents) + + audio = self.mel_spectrogram_to_waveform(mel_spectrogram) + + audio = audio[:, :original_waveform_length] + + if output_type == "np": + audio = audio.numpy() + + if not return_dict: + return (audio,) + + return AudioPipelineOutput(audios=audio) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/audioldm2/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/audioldm2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..23cd0e44f89217b8391d0ce236070271db9aaf83 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/audioldm2/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modeling_audioldm2"] = ["AudioLDM2ProjectionModel", "AudioLDM2UNet2DConditionModel"] + _import_structure["pipeline_audioldm2"] = ["AudioLDM2Pipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + + else: + from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel + from .pipeline_audioldm2 import AudioLDM2Pipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/audioldm2/modeling_audioldm2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/audioldm2/modeling_audioldm2.py new file mode 100644 index 0000000000000000000000000000000000000000..09aa0ad1700354695cd6ae5f5b5e5a655251e7a5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/audioldm2/modeling_audioldm2.py @@ -0,0 +1,1411 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import UNet2DConditionLoadersMixin +from ...models.activations import get_activation +from ...models.attention import AttentionMixin +from ...models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttnAddedKVProcessor, + AttnProcessor, +) +from ...models.embeddings import TimestepEmbedding, Timesteps +from ...models.modeling_utils import ModelMixin +from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D +from ...models.transformers.transformer_2d import Transformer2DModel +from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D +from ...models.unets.unet_2d_condition import UNet2DConditionOutput +from ...utils import BaseOutput, logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def add_special_tokens(hidden_states, attention_mask, sos_token, eos_token): + batch_size = hidden_states.shape[0] + + if attention_mask is not None: + # Add two more steps to attn mask + new_attn_mask_step = attention_mask.new_ones((batch_size, 1)) + attention_mask = torch.concat([new_attn_mask_step, attention_mask, new_attn_mask_step], dim=-1) + + # Add the SOS / EOS tokens at the start / end of the sequence respectively + sos_token = sos_token.expand(batch_size, 1, -1) + eos_token = eos_token.expand(batch_size, 1, -1) + hidden_states = torch.concat([sos_token, hidden_states, eos_token], dim=1) + return hidden_states, attention_mask + + +@dataclass +class AudioLDM2ProjectionModelOutput(BaseOutput): + """ + Args: + Class for AudioLDM2 projection layer's outputs. + hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states obtained by linearly projecting the hidden-states for each of the text + encoders and subsequently concatenating them together. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices, formed by concatenating the attention masks + for the two text encoders together. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + """ + + hidden_states: torch.Tensor + attention_mask: torch.LongTensor | None = None + + +class AudioLDM2ProjectionModel(ModelMixin, ConfigMixin): + """ + A simple linear projection model to map two text embeddings to a shared latent space. It also inserts learned + embedding vectors at the start and end of each text embedding sequence respectively. Each variable appended with + `_1` refers to that corresponding to the second text encoder. Otherwise, it is from the first. + + Args: + text_encoder_dim (`int`): + Dimensionality of the text embeddings from the first text encoder (CLAP). + text_encoder_1_dim (`int`): + Dimensionality of the text embeddings from the second text encoder (T5 or VITS). + langauge_model_dim (`int`): + Dimensionality of the text embeddings from the language model (GPT2). + """ + + @register_to_config + def __init__( + self, + text_encoder_dim, + text_encoder_1_dim, + langauge_model_dim, + use_learned_position_embedding=None, + max_seq_length=None, + ): + super().__init__() + # additional projection layers for each text encoder + self.projection = nn.Linear(text_encoder_dim, langauge_model_dim) + self.projection_1 = nn.Linear(text_encoder_1_dim, langauge_model_dim) + + # learnable SOS / EOS token embeddings for each text encoder + self.sos_embed = nn.Parameter(torch.ones(langauge_model_dim)) + self.eos_embed = nn.Parameter(torch.ones(langauge_model_dim)) + + self.sos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim)) + self.eos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim)) + + self.use_learned_position_embedding = use_learned_position_embedding + + # learable positional embedding for vits encoder + if self.use_learned_position_embedding is not None: + self.learnable_positional_embedding = torch.nn.Parameter( + torch.zeros((1, text_encoder_1_dim, max_seq_length)) + ) + + def forward( + self, + hidden_states: torch.Tensor | None = None, + hidden_states_1: torch.Tensor | None = None, + attention_mask: torch.LongTensor | None = None, + attention_mask_1: torch.LongTensor | None = None, + ): + hidden_states = self.projection(hidden_states) + hidden_states, attention_mask = add_special_tokens( + hidden_states, attention_mask, sos_token=self.sos_embed, eos_token=self.eos_embed + ) + + # Add positional embedding for Vits hidden state + if self.use_learned_position_embedding is not None: + hidden_states_1 = (hidden_states_1.permute(0, 2, 1) + self.learnable_positional_embedding).permute(0, 2, 1) + + hidden_states_1 = self.projection_1(hidden_states_1) + hidden_states_1, attention_mask_1 = add_special_tokens( + hidden_states_1, attention_mask_1, sos_token=self.sos_embed_1, eos_token=self.eos_embed_1 + ) + + # concatenate clap and t5 text encoding + hidden_states = torch.cat([hidden_states, hidden_states_1], dim=1) + + # concatenate attention masks + if attention_mask is None and attention_mask_1 is not None: + attention_mask = attention_mask_1.new_ones((hidden_states[:2])) + elif attention_mask is not None and attention_mask_1 is None: + attention_mask_1 = attention_mask.new_ones((hidden_states_1[:2])) + + if attention_mask is not None and attention_mask_1 is not None: + attention_mask = torch.cat([attention_mask, attention_mask_1], dim=-1) + else: + attention_mask = None + + return AudioLDM2ProjectionModelOutput( + hidden_states=hidden_states, + attention_mask=attention_mask, + ) + + +class AudioLDM2UNet2DConditionModel(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. Compared to the vanilla [`UNet2DConditionModel`], this variant optionally includes an additional + self-attention layer in each Transformer block, as well as multiple cross-attention layers. It also allows for up + to two cross-attention embeddings, `encoder_hidden_states` and `encoder_hidden_states_1`. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can only be `UNetMidBlock2DCrossAttn` for AudioLDM2. + up_block_types (`tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention (`bool` or `tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: int | None = None, + in_channels: int = 4, + out_channels: int = 4, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: str = "UNetMidBlock2DCrossAttn", + up_block_types: tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: bool | tuple[bool] = False, + block_out_channels: tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int | tuple[int] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int | None = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int | tuple[int] = 1280, + transformer_layers_per_block: int | tuple[int] = 1, + attention_head_dim: int | tuple[int] = 8, + num_attention_heads: int | tuple[int] | None = None, + use_linear_projection: bool = False, + class_embed_type: str | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + time_embedding_type: str = "positional", + time_embedding_dim: int | None = None, + time_embedding_act_fn: str | None = None, + timestep_post_act: str | None = None, + time_cond_proj_dim: int | None = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: int | None = None, + class_embeddings_concat: bool = False, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + if time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError(f"{time_embedding_type} does not exist. Please make sure to use `positional`.") + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + else: + raise ValueError( + f"unknown mid_block_type : {mid_block_type}. Should be `UNetMidBlock2DCrossAttn` for AudioLDM2." + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: list[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def forward( + self, + sample: torch.Tensor, + timestep: torch.Tensor | float | int, + encoder_hidden_states: torch.Tensor, + class_labels: torch.Tensor | None = None, + timestep_cond: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + return_dict: bool = True, + encoder_hidden_states_1: torch.Tensor | None = None, + encoder_attention_mask_1: torch.Tensor | None = None, + ) -> UNet2DConditionOutput | tuple: + r""" + The [`AudioLDM2UNet2DConditionModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + encoder_hidden_states_1 (`torch.Tensor`, *optional*): + A second set of encoder hidden states with shape `(batch, sequence_length_2, feature_dim_2)`. Can be + used to condition the model on a different set of embeddings to `encoder_hidden_states`. + encoder_attention_mask_1 (`torch.Tensor`, *optional*): + A cross-attention mask of shape `(batch, sequence_length_2)` is applied to `encoder_hidden_states_1`. + If `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + + Returns: + [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + if encoder_attention_mask_1 is not None: + encoder_attention_mask_1 = (1 - encoder_attention_mask_1.to(sample.dtype)) * -10000.0 + encoder_attention_mask_1 = encoder_attention_mask_1.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" + if isinstance(timestep, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + encoder_hidden_states_1=encoder_hidden_states_1, + encoder_attention_mask_1=encoder_attention_mask_1, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + encoder_hidden_states_1=encoder_hidden_states_1, + encoder_attention_mask_1=encoder_attention_mask_1, + ) + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + encoder_hidden_states_1=encoder_hidden_states_1, + encoder_attention_mask_1=encoder_attention_mask_1, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + transformer_layers_per_block=1, + num_attention_heads=None, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") + return CrossAttnDownBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + transformer_layers_per_block=1, + num_attention_heads=None, + resnet_groups=None, + cross_attention_dim=None, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") + return CrossAttnUpBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class CrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) + if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4: + raise ValueError( + "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention " + f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}" + ) + self.cross_attention_dim = cross_attention_dim + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + for j in range(len(cross_attention_dim)): + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim[j], + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + double_self_attention=True if cross_attention_dim[j] is None else False, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + encoder_hidden_states_1: torch.Tensor | None = None, + encoder_attention_mask_1: torch.Tensor | None = None, + ): + output_states = () + num_layers = len(self.resnets) + num_attention_per_layer = len(self.attentions) // num_layers + + encoder_hidden_states_1 = ( + encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states + ) + encoder_attention_mask_1 = ( + encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask + ) + + for i in range(num_layers): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb) + for idx, cross_attention_dim in enumerate(self.cross_attention_dim): + if cross_attention_dim is not None and idx <= 1: + forward_encoder_hidden_states = encoder_hidden_states + forward_encoder_attention_mask = encoder_attention_mask + elif cross_attention_dim is not None and idx > 1: + forward_encoder_hidden_states = encoder_hidden_states_1 + forward_encoder_attention_mask = encoder_attention_mask_1 + else: + forward_encoder_hidden_states = None + forward_encoder_attention_mask = None + hidden_states = self._gradient_checkpointing_func( + self.attentions[i * num_attention_per_layer + idx], + hidden_states, + forward_encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + forward_encoder_attention_mask, + )[0] + else: + hidden_states = self.resnets[i](hidden_states, temb) + for idx, cross_attention_dim in enumerate(self.cross_attention_dim): + if cross_attention_dim is not None and idx <= 1: + forward_encoder_hidden_states = encoder_hidden_states + forward_encoder_attention_mask = encoder_attention_mask + elif cross_attention_dim is not None and idx > 1: + forward_encoder_hidden_states = encoder_hidden_states_1 + forward_encoder_attention_mask = encoder_attention_mask_1 + else: + forward_encoder_hidden_states = None + forward_encoder_attention_mask = None + hidden_states = self.attentions[i * num_attention_per_layer + idx]( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=forward_encoder_hidden_states, + encoder_attention_mask=forward_encoder_attention_mask, + return_dict=False, + )[0] + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + use_linear_projection=False, + upcast_attention=False, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) + if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4: + raise ValueError( + "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention " + f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}" + ) + self.cross_attention_dim = cross_attention_dim + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for i in range(num_layers): + for j in range(len(cross_attention_dim)): + attentions.append( + Transformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim[j], + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + double_self_attention=True if cross_attention_dim[j] is None else False, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + encoder_hidden_states_1: torch.Tensor | None = None, + encoder_attention_mask_1: torch.Tensor | None = None, + ) -> torch.Tensor: + hidden_states = self.resnets[0](hidden_states, temb) + num_attention_per_layer = len(self.attentions) // (len(self.resnets) - 1) + + encoder_hidden_states_1 = ( + encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states + ) + encoder_attention_mask_1 = ( + encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask + ) + + for i in range(len(self.resnets[1:])): + if torch.is_grad_enabled() and self.gradient_checkpointing: + for idx, cross_attention_dim in enumerate(self.cross_attention_dim): + if cross_attention_dim is not None and idx <= 1: + forward_encoder_hidden_states = encoder_hidden_states + forward_encoder_attention_mask = encoder_attention_mask + elif cross_attention_dim is not None and idx > 1: + forward_encoder_hidden_states = encoder_hidden_states_1 + forward_encoder_attention_mask = encoder_attention_mask_1 + else: + forward_encoder_hidden_states = None + forward_encoder_attention_mask = None + hidden_states = self._gradient_checkpointing_func( + self.attentions[i * num_attention_per_layer + idx], + hidden_states, + forward_encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + forward_encoder_attention_mask, + )[0] + hidden_states = self._gradient_checkpointing_func(self.resnets[i + 1], hidden_states, temb) + else: + for idx, cross_attention_dim in enumerate(self.cross_attention_dim): + if cross_attention_dim is not None and idx <= 1: + forward_encoder_hidden_states = encoder_hidden_states + forward_encoder_attention_mask = encoder_attention_mask + elif cross_attention_dim is not None and idx > 1: + forward_encoder_hidden_states = encoder_hidden_states_1 + forward_encoder_attention_mask = encoder_attention_mask_1 + else: + forward_encoder_hidden_states = None + forward_encoder_attention_mask = None + hidden_states = self.attentions[i * num_attention_per_layer + idx]( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=forward_encoder_hidden_states, + encoder_attention_mask=forward_encoder_attention_mask, + return_dict=False, + )[0] + + hidden_states = self.resnets[i + 1](hidden_states, temb) + + return hidden_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) + if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4: + raise ValueError( + "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention " + f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}" + ) + self.cross_attention_dim = cross_attention_dim + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + for j in range(len(cross_attention_dim)): + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim[j], + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + double_self_attention=True if cross_attention_dim[j] is None else False, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + upsample_size: int | None = None, + attention_mask: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + encoder_hidden_states_1: torch.Tensor | None = None, + encoder_attention_mask_1: torch.Tensor | None = None, + ): + num_layers = len(self.resnets) + num_attention_per_layer = len(self.attentions) // num_layers + + encoder_hidden_states_1 = ( + encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states + ) + encoder_attention_mask_1 = ( + encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask + ) + + for i in range(num_layers): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb) + for idx, cross_attention_dim in enumerate(self.cross_attention_dim): + if cross_attention_dim is not None and idx <= 1: + forward_encoder_hidden_states = encoder_hidden_states + forward_encoder_attention_mask = encoder_attention_mask + elif cross_attention_dim is not None and idx > 1: + forward_encoder_hidden_states = encoder_hidden_states_1 + forward_encoder_attention_mask = encoder_attention_mask_1 + else: + forward_encoder_hidden_states = None + forward_encoder_attention_mask = None + hidden_states = self._gradient_checkpointing_func( + self.attentions[i * num_attention_per_layer + idx], + hidden_states, + forward_encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + forward_encoder_attention_mask, + )[0] + else: + hidden_states = self.resnets[i](hidden_states, temb) + for idx, cross_attention_dim in enumerate(self.cross_attention_dim): + if cross_attention_dim is not None and idx <= 1: + forward_encoder_hidden_states = encoder_hidden_states + forward_encoder_attention_mask = encoder_attention_mask + elif cross_attention_dim is not None and idx > 1: + forward_encoder_hidden_states = encoder_hidden_states_1 + forward_encoder_attention_mask = encoder_attention_mask_1 + else: + forward_encoder_hidden_states = None + forward_encoder_attention_mask = None + hidden_states = self.attentions[i * num_attention_per_layer + idx]( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=forward_encoder_hidden_states, + encoder_attention_mask=forward_encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/audioldm2/pipeline_audioldm2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/audioldm2/pipeline_audioldm2.py new file mode 100644 index 0000000000000000000000000000000000000000..b023974a33dd24414246ada048525f1bf6d69b0d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/audioldm2/pipeline_audioldm2.py @@ -0,0 +1,1125 @@ +# Copyright 2025 CVSSP, ByteDance and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import ( + ClapFeatureExtractor, + ClapModel, + GPT2LMHeadModel, + RobertaTokenizer, + RobertaTokenizerFast, + SpeechT5HifiGan, + T5EncoderModel, + T5Tokenizer, + T5TokenizerFast, + VitsModel, + VitsTokenizer, +) + +from ...models import AutoencoderKL +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + is_librosa_available, + logging, + replace_example_docstring, +) +from ...utils.import_utils import is_transformers_version +from ...utils.torch_utils import empty_device_cache, randn_tensor +from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline +from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel + + +if is_librosa_available(): + import librosa + + +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import scipy + >>> import torch + >>> from diffusers import AudioLDM2Pipeline + + >>> repo_id = "cvssp/audioldm2" + >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> # define the prompts + >>> prompt = "The sound of a hammer hitting a wooden surface." + >>> negative_prompt = "Low quality." + + >>> # set the seed for generator + >>> generator = torch.Generator("cuda").manual_seed(0) + + >>> # run the generation + >>> audio = pipe( + ... prompt, + ... negative_prompt=negative_prompt, + ... num_inference_steps=200, + ... audio_length_in_s=10.0, + ... num_waveforms_per_prompt=3, + ... generator=generator, + ... ).audios + + >>> # save the best audio sample (index 0) as a .wav file + >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio[0]) + ``` + ``` + #Using AudioLDM2 for Text To Speech + >>> import scipy + >>> import torch + >>> from diffusers import AudioLDM2Pipeline + + >>> repo_id = "anhnct/audioldm2_gigaspeech" + >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> # define the prompts + >>> prompt = "A female reporter is speaking" + >>> transcript = "wish you have a good day" + + >>> # set the seed for generator + >>> generator = torch.Generator("cuda").manual_seed(0) + + >>> # run the generation + >>> audio = pipe( + ... prompt, + ... transcription=transcript, + ... num_inference_steps=200, + ... audio_length_in_s=10.0, + ... num_waveforms_per_prompt=2, + ... generator=generator, + ... max_new_tokens=512, #Must set max_new_tokens equa to 512 for TTS + ... ).audios + + >>> # save the best audio sample (index 0) as a .wav file + >>> scipy.io.wavfile.write("tts.wav", rate=16000, data=audio[0]) + ``` +""" + + +def prepare_inputs_for_generation( + inputs_embeds, + attention_mask=None, + past_key_values=None, + **kwargs, +): + if past_key_values is not None: + # only last token for inputs_embeds if past is defined in kwargs + inputs_embeds = inputs_embeds[:, -1:] + + return { + "inputs_embeds": inputs_embeds, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + } + + +class AudioLDM2Pipeline(DiffusionPipeline): + r""" + Pipeline for text-to-audio generation using AudioLDM2. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.ClapModel`]): + First frozen text-encoder. AudioLDM2 uses the joint audio-text embedding model + [CLAP](https://huggingface.co/docs/transformers/model_doc/clap#transformers.CLAPTextModelWithProjection), + specifically the [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. The + text branch is used to encode the text prompt to a prompt embedding. The full audio-text model is used to + rank generated waveforms against the text prompt by computing similarity scores. + text_encoder_2 ([`~transformers.T5EncoderModel`, `~transformers.VitsModel`]): + Second frozen text-encoder. AudioLDM2 uses the encoder of + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [google/flan-t5-large](https://huggingface.co/google/flan-t5-large) variant. Second frozen text-encoder use + for TTS. AudioLDM2 uses the encoder of + [Vits](https://huggingface.co/docs/transformers/model_doc/vits#transformers.VitsModel). + projection_model ([`AudioLDM2ProjectionModel`]): + A trained model used to linearly project the hidden-states from the first and second text encoder models + and insert learned SOS and EOS token embeddings. The projected hidden-states from the two text encoders are + concatenated to give the input to the language model. A Learned Position Embedding for the Vits + hidden-states + language_model ([`~transformers.GPT2Model`]): + An auto-regressive language model used to generate a sequence of hidden-states conditioned on the projected + outputs from the two text encoders. + tokenizer ([`~transformers.RobertaTokenizer`]): + Tokenizer to tokenize text for the first frozen text-encoder. + tokenizer_2 ([`~transformers.T5Tokenizer`, `~transformers.VitsTokenizer`]): + Tokenizer to tokenize text for the second frozen text-encoder. + feature_extractor ([`~transformers.ClapFeatureExtractor`]): + Feature extractor to pre-process generated audio waveforms to log-mel spectrograms for automatic scoring. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded audio latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + vocoder ([`~transformers.SpeechT5HifiGan`]): + Vocoder of class `SpeechT5HifiGan` to convert the mel-spectrogram latents to the final audio waveform. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: ClapModel, + text_encoder_2: T5EncoderModel | VitsModel, + projection_model: AudioLDM2ProjectionModel, + language_model: GPT2LMHeadModel, + tokenizer: RobertaTokenizer | RobertaTokenizerFast, + tokenizer_2: T5Tokenizer | T5TokenizerFast | VitsTokenizer, + feature_extractor: ClapFeatureExtractor, + unet: AudioLDM2UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + vocoder: SpeechT5HifiGan, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + projection_model=projection_model, + language_model=language_model, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + feature_extractor=feature_extractor, + unet=unet, + scheduler=scheduler, + vocoder=vocoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_model_cpu_offload(self, gpu_id: int | None = None, device: torch.device | str = "cuda"): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + torch_device = torch.device(device) + device_index = torch_device.index + + if gpu_id is not None and device_index is not None: + raise ValueError( + f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}" + f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}" + ) + + device_type = torch_device.type + device_str = device_type + if gpu_id or torch_device.index: + device_str = f"{device_str}:{gpu_id or torch_device.index}" + device = torch.device(device_str) + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + empty_device_cache(device.type) + + model_sequence = [ + self.text_encoder.text_model, + self.text_encoder.text_projection, + self.text_encoder_2, + self.projection_model, + self.language_model, + self.unet, + self.vae, + self.vocoder, + self.text_encoder, + ] + + hook = None + for cpu_offloaded_model in model_sequence: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + def generate_language_model( + self, + inputs_embeds: torch.Tensor = None, + max_new_tokens: int = 8, + **model_kwargs, + ): + """ + + Generates a sequence of hidden-states from the language model, conditioned on the embedding inputs. + + Parameters: + inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + The sequence used as a prompt for the generation. + max_new_tokens (`int`): + Number of new tokens to generate. + model_kwargs (`dict[str, Any]`, *optional*): + Ad hoc parametrization of additional model-specific kwargs that will be forwarded to the `forward` + function of the model. + + Return: + `inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + The sequence of generated hidden-states. + """ + cache_position_kwargs = {} + if is_transformers_version("<", "4.52.1"): + cache_position_kwargs["input_ids"] = inputs_embeds + else: + cache_position_kwargs["seq_length"] = inputs_embeds.shape[0] + cache_position_kwargs["device"] = ( + self.language_model.device if getattr(self, "language_model", None) is not None else self.device + ) + cache_position_kwargs["model_kwargs"] = model_kwargs + max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens + model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs) + + for _ in range(max_new_tokens): + # prepare model inputs + model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs) + + # forward pass to get next hidden states + output = self.language_model(**model_inputs, output_hidden_states=True, return_dict=True) + + next_hidden_states = output.hidden_states[-1] + + # Update the model input + inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1) + + # Update generated hidden states, model inputs, and length for next step + model_kwargs = self.language_model._update_model_kwargs_for_generation(output, model_kwargs) + + return inputs_embeds[:, -max_new_tokens:, :] + + def encode_prompt( + self, + prompt, + device, + num_waveforms_per_prompt, + do_classifier_free_guidance, + transcription=None, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + generated_prompt_embeds: torch.Tensor | None = None, + negative_generated_prompt_embeds: torch.Tensor | None = None, + attention_mask: torch.LongTensor | None = None, + negative_attention_mask: torch.LongTensor | None = None, + max_new_tokens: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + transcription (`str` or `list[str]`): + transcription of text to speech + device (`torch.device`): + torch device + num_waveforms_per_prompt (`int`): + number of waveforms that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the audio generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-computed text embeddings from the Flan T5 model. Can be used to easily tweak text inputs, *e.g.* + prompt weighting. If not provided, text embeddings will be computed from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-computed negative text embeddings from the Flan T5 model. Can be used to easily tweak text inputs, + *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from + `negative_prompt` input argument. + generated_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings from the GPT2 language model. Can be used to easily tweak text inputs, + *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input + argument. + negative_generated_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text + inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from + `negative_prompt` input argument. + attention_mask (`torch.LongTensor`, *optional*): + Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will + be computed from `prompt` input argument. + negative_attention_mask (`torch.LongTensor`, *optional*): + Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention + mask will be computed from `negative_prompt` input argument. + max_new_tokens (`int`, *optional*, defaults to None): + The number of new tokens to generate with the GPT2 language model. + Returns: + prompt_embeds (`torch.Tensor`): + Text embeddings from the Flan T5 model. + attention_mask (`torch.LongTensor`): + Attention mask to be applied to the `prompt_embeds`. + generated_prompt_embeds (`torch.Tensor`): + Text embeddings generated from the GPT2 language model. + + Example: + + ```python + >>> import scipy + >>> import torch + >>> from diffusers import AudioLDM2Pipeline + + >>> repo_id = "cvssp/audioldm2" + >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> # Get text embedding vectors + >>> prompt_embeds, attention_mask, generated_prompt_embeds = pipe.encode_prompt( + ... prompt="Techno music with a strong, upbeat tempo and high melodic riffs", + ... device="cuda", + ... do_classifier_free_guidance=True, + ... ) + + >>> # Pass text embeddings to pipeline for text-conditional audio generation + >>> audio = pipe( + ... prompt_embeds=prompt_embeds, + ... attention_mask=attention_mask, + ... generated_prompt_embeds=generated_prompt_embeds, + ... num_inference_steps=200, + ... audio_length_in_s=10.0, + ... ).audios[0] + + >>> # save generated audio sample + >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio) + ```""" + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] + is_vits_text_encoder = isinstance(self.text_encoder_2, VitsModel) + + if is_vits_text_encoder: + text_encoders = [self.text_encoder, self.text_encoder_2.text_encoder] + else: + text_encoders = [self.text_encoder, self.text_encoder_2] + + if prompt_embeds is None: + prompt_embeds_list = [] + attention_mask_list = [] + + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + use_prompt = isinstance( + tokenizer, (RobertaTokenizer, RobertaTokenizerFast, T5Tokenizer, T5TokenizerFast) + ) + text_inputs = tokenizer( + prompt if use_prompt else transcription, + padding="max_length" + if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer)) + else True, + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + f"The following part of your input was truncated because {text_encoder.config.model_type} can " + f"only handle sequences up to {tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_input_ids = text_input_ids.to(device) + attention_mask = attention_mask.to(device) + + if text_encoder.config.model_type == "clap": + prompt_embeds = text_encoder.get_text_features( + text_input_ids, + attention_mask=attention_mask, + ) + # Extract the pooler output if it's a BaseModelOutputWithPooling (Transformers v5+) + # otherwise use it directly (Transformers v4) + if hasattr(prompt_embeds, "pooler_output"): + prompt_embeds = prompt_embeds.pooler_output + # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size) + prompt_embeds = prompt_embeds[:, None, :] + # make sure that we attend to this single hidden-state + attention_mask = attention_mask.new_ones((batch_size, 1)) + elif is_vits_text_encoder: + # Add end_token_id and attention mask in the end of sequence phonemes + for text_input_id, text_attention_mask in zip(text_input_ids, attention_mask): + for idx, phoneme_id in enumerate(text_input_id): + if phoneme_id == 0: + text_input_id[idx] = 182 + text_attention_mask[idx] = 1 + break + prompt_embeds = text_encoder( + text_input_ids, attention_mask=attention_mask, padding_mask=attention_mask.unsqueeze(-1) + ) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds_list.append(prompt_embeds) + attention_mask_list.append(attention_mask) + + projection_output = self.projection_model( + hidden_states=prompt_embeds_list[0], + hidden_states_1=prompt_embeds_list[1], + attention_mask=attention_mask_list[0], + attention_mask_1=attention_mask_list[1], + ) + projected_prompt_embeds = projection_output.hidden_states + projected_attention_mask = projection_output.attention_mask + + generated_prompt_embeds = self.generate_language_model( + projected_prompt_embeds, + attention_mask=projected_attention_mask, + max_new_tokens=max_new_tokens, + ) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + attention_mask = ( + attention_mask.to(device=device) + if attention_mask is not None + else torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=device) + ) + generated_prompt_embeds = generated_prompt_embeds.to(dtype=self.language_model.dtype, device=device) + + bs_embed, seq_len, hidden_size = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len, hidden_size) + + # duplicate attention mask for each generation per prompt + attention_mask = attention_mask.repeat(1, num_waveforms_per_prompt) + attention_mask = attention_mask.view(bs_embed * num_waveforms_per_prompt, seq_len) + + bs_embed, seq_len, hidden_size = generated_prompt_embeds.shape + # duplicate generated embeddings for each generation per prompt, using mps friendly method + generated_prompt_embeds = generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1) + generated_prompt_embeds = generated_prompt_embeds.view( + bs_embed * num_waveforms_per_prompt, seq_len, hidden_size + ) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + negative_attention_mask_list = [] + max_length = prompt_embeds.shape[1] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=tokenizer.model_max_length + if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer)) + else max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_input_ids = uncond_input.input_ids.to(device) + negative_attention_mask = uncond_input.attention_mask.to(device) + + if text_encoder.config.model_type == "clap": + negative_prompt_embeds = text_encoder.get_text_features( + uncond_input_ids, + attention_mask=negative_attention_mask, + ) + # Extract the pooler output if it's a BaseModelOutputWithPooling (Transformers v5+) + # otherwise use it directly (Transformers v4) + if hasattr(negative_prompt_embeds, "pooler_output"): + negative_prompt_embeds = negative_prompt_embeds.pooler_output + # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size) + negative_prompt_embeds = negative_prompt_embeds[:, None, :] + # make sure that we attend to this single hidden-state + negative_attention_mask = negative_attention_mask.new_ones((batch_size, 1)) + elif is_vits_text_encoder: + negative_prompt_embeds = torch.zeros( + batch_size, + tokenizer.model_max_length, + text_encoder.config.hidden_size, + ).to(dtype=self.text_encoder_2.dtype, device=device) + negative_attention_mask = torch.zeros(batch_size, tokenizer.model_max_length).to( + dtype=self.text_encoder_2.dtype, device=device + ) + else: + negative_prompt_embeds = text_encoder( + uncond_input_ids, + attention_mask=negative_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + negative_attention_mask_list.append(negative_attention_mask) + + projection_output = self.projection_model( + hidden_states=negative_prompt_embeds_list[0], + hidden_states_1=negative_prompt_embeds_list[1], + attention_mask=negative_attention_mask_list[0], + attention_mask_1=negative_attention_mask_list[1], + ) + negative_projected_prompt_embeds = projection_output.hidden_states + negative_projected_attention_mask = projection_output.attention_mask + + negative_generated_prompt_embeds = self.generate_language_model( + negative_projected_prompt_embeds, + attention_mask=negative_projected_attention_mask, + max_new_tokens=max_new_tokens, + ) + + if do_classifier_free_guidance: + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + negative_attention_mask = ( + negative_attention_mask.to(device=device) + if negative_attention_mask is not None + else torch.ones(negative_prompt_embeds.shape[:2], dtype=torch.long, device=device) + ) + negative_generated_prompt_embeds = negative_generated_prompt_embeds.to( + dtype=self.language_model.dtype, device=device + ) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len, -1) + + # duplicate unconditional attention mask for each generation per prompt + negative_attention_mask = negative_attention_mask.repeat(1, num_waveforms_per_prompt) + negative_attention_mask = negative_attention_mask.view(batch_size * num_waveforms_per_prompt, seq_len) + + # duplicate unconditional generated embeddings for each generation per prompt + seq_len = negative_generated_prompt_embeds.shape[1] + negative_generated_prompt_embeds = negative_generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1) + negative_generated_prompt_embeds = negative_generated_prompt_embeds.view( + batch_size * num_waveforms_per_prompt, seq_len, -1 + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + attention_mask = torch.cat([negative_attention_mask, attention_mask]) + generated_prompt_embeds = torch.cat([negative_generated_prompt_embeds, generated_prompt_embeds]) + + return prompt_embeds, attention_mask, generated_prompt_embeds + + # Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.mel_spectrogram_to_waveform + def mel_spectrogram_to_waveform(self, mel_spectrogram): + if mel_spectrogram.dim() == 4: + mel_spectrogram = mel_spectrogram.squeeze(1) + + waveform = self.vocoder(mel_spectrogram) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + waveform = waveform.cpu().float() + return waveform + + def score_waveforms(self, text, audio, num_waveforms_per_prompt, device, dtype): + if not is_librosa_available(): + logger.info( + "Automatic scoring of the generated audio waveforms against the input prompt text requires the " + "`librosa` package to resample the generated waveforms. Returning the audios in the order they were " + "generated. To enable automatic scoring, install `librosa` with: `pip install librosa`." + ) + return audio + inputs = self.tokenizer(text, return_tensors="pt", padding=True) + resampled_audio = librosa.resample( + audio.numpy(), orig_sr=self.vocoder.config.sampling_rate, target_sr=self.feature_extractor.sampling_rate + ) + inputs["input_features"] = self.feature_extractor( + list(resampled_audio), return_tensors="pt", sampling_rate=self.feature_extractor.sampling_rate + ).input_features.type(dtype) + inputs = inputs.to(device) + + # compute the audio-text similarity score using the CLAP model + logits_per_text = self.text_encoder(**inputs).logits_per_text + # sort by the highest matching generations per prompt + indices = torch.argsort(logits_per_text, dim=1, descending=True)[:, :num_waveforms_per_prompt] + audio = torch.index_select(audio, 0, indices.reshape(-1).cpu()) + return audio + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + audio_length_in_s, + vocoder_upsample_factor, + callback_steps, + transcription=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + generated_prompt_embeds=None, + negative_generated_prompt_embeds=None, + attention_mask=None, + negative_attention_mask=None, + ): + min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor + if audio_length_in_s < min_audio_length_in_s: + raise ValueError( + f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but " + f"is {audio_length_in_s}." + ) + + if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0: + raise ValueError( + f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the " + f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of " + f"{self.vae_scale_factor}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and (prompt_embeds is None or generated_prompt_embeds is None): + raise ValueError( + "Provide either `prompt`, or `prompt_embeds` and `generated_prompt_embeds`. Cannot leave " + "`prompt` undefined without specifying both `prompt_embeds` and `generated_prompt_embeds`." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_embeds is not None and negative_generated_prompt_embeds is None: + raise ValueError( + "Cannot forward `negative_prompt_embeds` without `negative_generated_prompt_embeds`. Ensure that" + "both arguments are specified" + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]: + raise ValueError( + "`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:" + f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}" + ) + + if transcription is None: + if self.text_encoder_2.config.model_type == "vits": + raise ValueError("Cannot forward without transcription. Please make sure to have transcription") + elif transcription is not None and ( + not isinstance(transcription, str) and not isinstance(transcription, list) + ): + raise ValueError(f"`transcription` has to be of type `str` or `list` but is {type(transcription)}") + + if generated_prompt_embeds is not None and negative_generated_prompt_embeds is not None: + if generated_prompt_embeds.shape != negative_generated_prompt_embeds.shape: + raise ValueError( + "`generated_prompt_embeds` and `negative_generated_prompt_embeds` must have the same shape when " + f"passed directly, but got: `generated_prompt_embeds` {generated_prompt_embeds.shape} != " + f"`negative_generated_prompt_embeds` {negative_generated_prompt_embeds.shape}." + ) + if ( + negative_attention_mask is not None + and negative_attention_mask.shape != negative_prompt_embeds.shape[:2] + ): + raise ValueError( + "`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:" + f"`attention_mask: {negative_attention_mask.shape} != `prompt_embeds` {negative_prompt_embeds.shape}" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim + def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(self.vocoder.config.model_in_dim) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + transcription: str | list[str] = None, + audio_length_in_s: float | None = None, + num_inference_steps: int = 200, + guidance_scale: float = 3.5, + negative_prompt: str | list[str] | None = None, + num_waveforms_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + generated_prompt_embeds: torch.Tensor | None = None, + negative_generated_prompt_embeds: torch.Tensor | None = None, + attention_mask: torch.LongTensor | None = None, + negative_attention_mask: torch.LongTensor | None = None, + max_new_tokens: int | None = None, + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int | None = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + output_type: str | None = "np", + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`. + transcription (`str` or `list[str]`, *optional*):\ + The transcript for text to speech. + audio_length_in_s (`int`, *optional*, defaults to 10.24): + The length of the generated audio sample in seconds. + num_inference_steps (`int`, *optional*, defaults to 200): + The number of denoising steps. More denoising steps usually lead to a higher quality audio at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 3.5): + A higher guidance scale value encourages the model to generate audio that is closely linked to the text + `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in audio generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_waveforms_per_prompt (`int`, *optional*, defaults to 1): + The number of waveforms to generate per prompt. If `num_waveforms_per_prompt > 1`, then automatic + scoring is performed between the generated outputs and the text prompt. This scoring ranks the + generated waveforms based on their cosine similarity with the text input in the joint text-audio + embedding space. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for spectrogram + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + generated_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings from the GPT2 language model. Can be used to easily tweak text inputs, + *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input + argument. + negative_generated_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text + inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from + `negative_prompt` input argument. + attention_mask (`torch.LongTensor`, *optional*): + Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will + be computed from `prompt` input argument. + negative_attention_mask (`torch.LongTensor`, *optional*): + Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention + mask will be computed from `negative_prompt` input argument. + max_new_tokens (`int`, *optional*, defaults to None): + Number of new tokens to generate with the GPT2 language model. If not provided, number of tokens will + be taken from the config of the model. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or + `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion + model (LDM) output. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated audio. + """ + # 0. Convert audio input length from seconds to spectrogram height + vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate + + if audio_length_in_s is None: + audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor + + height = int(audio_length_in_s / vocoder_upsample_factor) + + original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate) + if height % self.vae_scale_factor != 0: + height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor + logger.info( + f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} " + f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the " + f"denoising process." + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + audio_length_in_s, + vocoder_upsample_factor, + callback_steps, + transcription, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + generated_prompt_embeds, + negative_generated_prompt_embeds, + attention_mask, + negative_attention_mask, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, attention_mask, generated_prompt_embeds = self.encode_prompt( + prompt, + device, + num_waveforms_per_prompt, + do_classifier_free_guidance, + transcription, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + generated_prompt_embeds=generated_prompt_embeds, + negative_generated_prompt_embeds=negative_generated_prompt_embeds, + attention_mask=attention_mask, + negative_attention_mask=negative_attention_mask, + max_new_tokens=max_new_tokens, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_waveforms_per_prompt, + num_channels_latents, + height, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=generated_prompt_embeds, + encoder_hidden_states_1=prompt_embeds, + encoder_attention_mask_1=attention_mask, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + self.maybe_free_model_hooks() + + # 8. Post-processing + if not output_type == "latent": + latents = 1 / self.vae.config.scaling_factor * latents + mel_spectrogram = self.vae.decode(latents).sample + else: + return AudioPipelineOutput(audios=latents) + + audio = self.mel_spectrogram_to_waveform(mel_spectrogram) + + audio = audio[:, :original_waveform_length] + + # 9. Automatic scoring + if num_waveforms_per_prompt > 1 and prompt is not None: + audio = self.score_waveforms( + text=prompt, + audio=audio, + num_waveforms_per_prompt=num_waveforms_per_prompt, + device=device, + dtype=prompt_embeds.dtype, + ) + + if output_type == "np": + audio = audio.numpy() + + if not return_dict: + return (audio,) + + return AudioPipelineOutput(audios=audio) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/aura_flow/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/aura_flow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e1917baa61e28abf0970080c90b5864579cbe8f9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/aura_flow/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_aura_flow"] = ["AuraFlowPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_aura_flow import AuraFlowPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/aura_flow/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/aura_flow/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a10002c6fc0e0e004b73746d846b01b4bfae67ea Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/aura_flow/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/aura_flow/__pycache__/pipeline_aura_flow.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/aura_flow/__pycache__/pipeline_aura_flow.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d269d5451ccc7756e3fcf5e130aa434d019f948 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/aura_flow/__pycache__/pipeline_aura_flow.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/aura_flow/pipeline_aura_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..c609bc8b29f88c2be0d373852aaa846038d55319 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -0,0 +1,671 @@ +# Copyright 2025 AuraFlow Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import Any, Callable + +import torch +from transformers import T5Tokenizer, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...loaders import AuraFlowLoraLoaderMixin +from ...models import AuraFlowTransformer2DModel, AutoencoderKL +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AuraFlowPipeline + + >>> pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> image = pipe(prompt).images[0] + >>> image.save("aura_flow.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin): + r""" + Args: + tokenizer (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. AuraFlow uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [EleutherAI/pile-t5-xl](https://huggingface.co/EleutherAI/pile-t5-xl) variant. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + transformer ([`AuraFlowTransformer2DModel`]): + Conditional Transformer (MMDiT and DiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKL, + transformer: AuraFlowTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] = None, + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + max_sequence_length: int = 256, + lora_scale: float | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, AuraFlowLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + if device is None: + device = self._execution_device + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + max_length = max_sequence_length + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + truncation=True, + max_length=max_length, + padding="max_length", + return_tensors="pt", + ) + text_input_ids = text_inputs["input_ids"] + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because T5 can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + text_inputs = {k: v.to(device) for k, v in text_inputs.items()} + prompt_embeds = self.text_encoder(**text_inputs)[0] + prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape) + prompt_embeds = prompt_embeds * prompt_attention_mask + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.reshape(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + truncation=True, + max_length=max_length, + padding="max_length", + return_tensors="pt", + ) + uncond_input = {k: v.to(device) for k, v in uncond_input.items()} + negative_prompt_embeds = self.text_encoder(**uncond_input)[0] + negative_prompt_attention_mask = ( + uncond_input["attention_mask"].unsqueeze(-1).expand(negative_prompt_embeds.shape) + ) + negative_prompt_embeds = negative_prompt_embeds * negative_prompt_attention_mask + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.reshape(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + if self.text_encoder is not None: + if isinstance(self, AuraFlowLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + num_inference_steps: int = 50, + sigmas: list[float] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: int | None = 1, + height: int | None = 1024, + width: int | None = 1024, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + max_sequence_length: int = 256, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ) -> ImagePipelineOutput | tuple: + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for best results. + width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images. + """ + # 1. Check inputs. Raise error if not correct + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + + # 2. Determine batch size. + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + + # sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas + ) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = torch.tensor([t / 1000]).expand(latent_model_input.shape[0]) + timestep = timestep.to(latents.device, dtype=latents.dtype) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + return_dict=False, + attention_kwargs=self.attention_kwargs, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/blip_diffusion/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/blip_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c245313e2f8ac578542b2db88750cea64de67b94 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/blip_diffusion/__init__.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass + +import numpy as np +import PIL +from PIL import Image + +from ...utils import OptionalDependencyNotAvailable, is_torch_available, is_transformers_available + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline +else: + from .blip_image_processing import BlipImageProcessor + from .modeling_blip2 import Blip2QFormerModel + from .modeling_ctx_clip import ContextCLIPTextModel + from .pipeline_blip_diffusion import BlipDiffusionPipeline diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/blip_diffusion/blip_image_processing.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/blip_diffusion/blip_image_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2a0186f041d63e9274bdbcc2f8b53fb072175a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/blip_diffusion/blip_image_processing.py @@ -0,0 +1,316 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for BLIP.""" + +import numpy as np +import torch +from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from transformers.image_transforms import convert_to_rgb, resize, to_channel_dimension_format +from transformers.image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, +) +from transformers.utils import TensorType, is_vision_available, logging + +from diffusers.utils import numpy_to_pil + + +if is_vision_available(): + import PIL.Image + + +logger = logging.get_logger(__name__) + + +# We needed some extra functions on top of the ones in transformers.image_processing_utils.BaseImageProcessor, namely center crop +# Copy-pasted from transformers.models.blip.image_processing_blip.BlipImageProcessor +class BlipImageProcessor(BaseImageProcessor): + r""" + Constructs a BLIP image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the + `do_resize` parameter in the `preprocess` method. + size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`): + Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be + overridden by the `resample` parameter in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Wwhether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the + `do_rescale` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be + overridden by the `rescale_factor` parameter in the `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. Can be overridden by the `do_normalize` parameter in the `preprocess` method. + image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be + overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: int | float = 1 / 255, + do_normalize: bool = True, + image_mean: float | list[float] | None = None, + image_std: float | list[float] | None = None, + do_convert_rgb: bool = True, + do_center_crop: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 224, "width": 224} + size = get_size_dict(size, default_to_square=True) + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.do_convert_rgb = do_convert_rgb + self.do_center_crop = do_center_crop + + # Copy-pasted from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC + def resize( + self, + image: np.ndarray, + size: dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: str | ChannelDimension | None = None, + input_data_format: str | ChannelDimension | None = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + Returns: + `np.ndarray`: The resized image. + """ + size = get_size_dict(size) + if "height" not in size or "width" not in size: + raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}") + output_size = (size["height"], size["width"]) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def preprocess( + self, + images: ImageInput, + do_resize: bool | None = None, + size: dict[str, int] | None = None, + resample: PILImageResampling = None, + do_rescale: bool | None = None, + do_center_crop: bool | None = None, + rescale_factor: float | None = None, + do_normalize: bool | None = None, + image_mean: float | list[float] | None = None, + image_std: float | list[float] | None = None, + return_tensors: str | TensorType | None = None, + do_convert_rgb: bool = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: str | ChannelDimension | None = None, + **kwargs, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Controls the size of the image after `resize`. The shortest edge of the image is resized to + `size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image + is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest + edge equal to `int(size["shortest_edge"] * (1333 / 800))`. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Image mean to normalize the image by if `do_normalize` is set to `True`. + image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to normalize the image by if `do_normalize` is set to `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + if do_resize and size is None or resample is None: + raise ValueError("Size and resample must be specified if do_resize is True.") + + if do_rescale and rescale_factor is None: + raise ValueError("Rescale factor must be specified if do_rescale is True.") + + if do_normalize and (image_mean is None or image_std is None): + raise ValueError("Image mean and std must be specified if do_normalize is True.") + + # PIL RGBA images are converted to RGB + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + images = [ + self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + if do_center_crop: + images = [self.center_crop(image, size, input_data_format=input_data_format) for image in images] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors) + return encoded_outputs + + # Follows diffusers.VaeImageProcessor.postprocess + def postprocess(self, sample: torch.Tensor, output_type: str = "pil"): + if output_type not in ["pt", "np", "pil"]: + raise ValueError( + f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']" + ) + + # Equivalent to diffusers.VaeImageProcessor.denormalize + sample = (sample / 2 + 0.5).clamp(0, 1) + if output_type == "pt": + return sample + + # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy + sample = sample.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "np": + return sample + # Output_type must be 'pil' + sample = numpy_to_pil(sample) + return sample diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/blip_diffusion/modeling_blip2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/blip_diffusion/modeling_blip2.py new file mode 100644 index 0000000000000000000000000000000000000000..c434ccdaccca70a3a3cb2f02b015b417071a0950 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/blip_diffusion/modeling_blip2.py @@ -0,0 +1,636 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch import nn +from transformers import BertTokenizer +from transformers.activations import QuickGELUActivation as QuickGELU +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPooling, + BaseModelOutputWithPoolingAndCrossAttentions, +) +from transformers.models.blip_2.configuration_blip_2 import Blip2Config, Blip2VisionConfig +from transformers.models.blip_2.modeling_blip_2 import ( + Blip2Encoder, + Blip2PreTrainedModel, + Blip2QFormerAttention, + Blip2QFormerIntermediate, + Blip2QFormerOutput, +) +from transformers.pytorch_utils import apply_chunking_to_forward +from transformers.utils import ( + logging, + replace_return_docstrings, +) + + +logger = logging.get_logger(__name__) + + +# There is an implementation of Blip2 in `transformers` : https://github.com/huggingface/transformers/blob/main/src/transformers/models/blip_2/modeling_blip_2.py. +# But it doesn't support getting multimodal embeddings. So, this module can be +# replaced with a future `transformers` version supports that. +class Blip2TextEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + batch_size = embeddings.shape[0] + # repeat the query embeddings for batch size + query_embeds = query_embeds.repeat(batch_size, 1, 1) + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + embeddings = embeddings.to(query_embeds.dtype) + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Blip2 +class Blip2VisionEmbeddings(nn.Module): + def __init__(self, config: Blip2VisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype) + return embeddings + + +# The Qformer encoder, which takes the visual embeddings, and the text input, to get multimodal embeddings +class Blip2QFormerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [Blip2QFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions else None + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and torch.is_grad_enabled(): + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + layer_outputs = self._gradient_checkpointing_func( + layer_module, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if layer_module.has_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# The layers making up the Qformer encoder +class Blip2QFormerLayer(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = Blip2QFormerAttention(config) + + self.layer_idx = layer_idx + + if layer_idx % config.cross_attention_frequency == 0: + self.crossattention = Blip2QFormerAttention(config, is_cross_attention=True) + self.has_cross_attention = True + else: + self.has_cross_attention = False + + self.intermediate = Blip2QFormerIntermediate(config) + self.intermediate_query = Blip2QFormerIntermediate(config) + self.output_query = Blip2QFormerOutput(config) + self.output = Blip2QFormerOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + if encoder_hidden_states is None: + raise ValueError("encoder_hidden_states must be given for cross-attention layers") + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:-1] + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat([layer_output, layer_output_text], dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query(self, attention_output): + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +# ProjLayer used to project the multimodal Blip2 embeddings to be used in the text encoder +class ProjLayer(nn.Module): + def __init__(self, in_dim, out_dim, hidden_dim, drop_p=0.1, eps=1e-12): + super().__init__() + + # Dense1 -> Act -> Dense2 -> Drop -> Res -> Norm + self.dense1 = nn.Linear(in_dim, hidden_dim) + self.act_fn = QuickGELU() + self.dense2 = nn.Linear(hidden_dim, out_dim) + self.dropout = nn.Dropout(drop_p) + + self.LayerNorm = nn.LayerNorm(out_dim, eps=eps) + + def forward(self, x): + x_in = x + + x = self.LayerNorm(x) + x = self.dropout(self.dense2(self.act_fn(self.dense1(x)))) + x_in + + return x + + +# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->Blip2, BLIP->BLIP_2 +class Blip2VisionModel(Blip2PreTrainedModel): + main_input_name = "pixel_values" + config_class = Blip2VisionConfig + + def __init__(self, config: Blip2VisionConfig): + super().__init__(config) + self.config = config + embed_dim = config.hidden_size + self.embeddings = Blip2VisionEmbeddings(config) + self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = Blip2Encoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.post_init() + + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Blip2VisionConfig) + def forward( + self, + pixel_values: torch.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | BaseModelOutputWithPooling: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layernorm(hidden_states) + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def get_input_embeddings(self): + return self.embeddings + + +# Qformer model, used to get multimodal embeddings from the text and image inputs +class Blip2QFormerModel(Blip2PreTrainedModel): + """ + Querying Transformer (Q-Former), used in BLIP-2. + """ + + def __init__(self, config: Blip2Config): + super().__init__(config) + self.config = config + self.embeddings = Blip2TextEmbeddings(config.qformer_config) + self.visual_encoder = Blip2VisionModel(config.vision_config) + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + if not hasattr(config, "tokenizer") or config.tokenizer is None: + self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right") + else: + self.tokenizer = BertTokenizer.from_pretrained(config.tokenizer, truncation_side="right") + self.tokenizer.add_special_tokens({"bos_token": "[DEC]"}) + self.proj_layer = ProjLayer( + in_dim=config.qformer_config.hidden_size, + out_dim=config.qformer_config.hidden_size, + hidden_dim=config.qformer_config.hidden_size * 4, + drop_p=0.1, + eps=1e-12, + ) + + self.encoder = Blip2QFormerEncoder(config.qformer_config) + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: torch.Tensor, + input_shape: tuple[int], + device: torch.device, + has_query: bool = False, + ) -> torch.Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (`tuple[int]`): + The shape of the input to the model. + device (`torch.device`): + The device of the input to the model. + + Returns: + `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + text_input=None, + image_input=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.Tensor))` of length `config.n_layers` with each tuple having 4 tensors of: + shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and + value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are + used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key + value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape + `(batch_size, sequence_length)`. + use_cache (`bool`, `optional`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + + text = self.tokenizer(text_input, return_tensors="pt", padding=True) + text = text.to(self.device) + input_ids = text.input_ids + batch_size = input_ids.shape[0] + query_atts = torch.ones((batch_size, self.query_tokens.size()[1]), dtype=torch.long).to(self.device) + attention_mask = torch.cat([query_atts, text.attention_mask], dim=1) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0 + ) + + query_length = self.query_tokens.shape[1] + + embedding_output = self.embeddings( + input_ids=input_ids, + query_embeds=self.query_tokens, + past_key_values_length=past_key_values_length, + ) + + # embedding_output = self.layernorm(query_embeds) + # embedding_output = self.dropout(embedding_output) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + image_embeds_frozen = self.visual_encoder(image_input).last_hidden_state + # image_embeds_frozen = torch.ones_like(image_embeds_frozen) + encoder_hidden_states = image_embeds_frozen + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if isinstance(encoder_hidden_states, list): + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if isinstance(encoder_attention_mask, list): + encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.qformer_config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = sequence_output[:, 0, :] + + if not return_dict: + return self.proj_layer(sequence_output[:, :query_length, :]) + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..c5364f8985aabd82679caf5bfc20d04c1d93686e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py @@ -0,0 +1,221 @@ +# Copyright 2025 Salesforce.com, inc. +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch import nn +from transformers import CLIPPreTrainedModel +from transformers.modeling_outputs import BaseModelOutputWithPooling +from transformers.models.clip.configuration_clip import CLIPTextConfig +from transformers.models.clip.modeling_clip import CLIPEncoder + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +# This is a modified version of the CLIPTextModel from transformers.models.clip.modeling_clip +# Which allows for an extra input of "context embeddings", which are the query embeddings used in Qformer +# They pass through the clip model, along with the text embeddings, and interact with them using self attention +class ContextCLIPTextModel(CLIPPreTrainedModel): + config_class = CLIPTextConfig + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPTextConfig): + super().__init__(config) + self.text_model = ContextCLIPTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + ctx_embeddings: torch.Tensor = None, + ctx_begin_pos: list = None, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | BaseModelOutputWithPooling: + return self.text_model( + ctx_embeddings=ctx_embeddings, + ctx_begin_pos=ctx_begin_pos, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class ContextCLIPTextTransformer(nn.Module): + def __init__(self, config: CLIPTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = ContextCLIPTextEmbeddings(config) + self.encoder = CLIPEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim) + + def forward( + self, + ctx_embeddings: torch.Tensor, + ctx_begin_pos: list, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | BaseModelOutputWithPooling: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify either input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + ctx_embeddings=ctx_embeddings, + ctx_begin_pos=ctx_begin_pos, + ) + + bsz, seq_len = input_shape + if ctx_embeddings is not None: + seq_len += ctx_embeddings.size(1) + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to( + hidden_states.device + ) + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # text_embeds.shape = [batch_size, sequence_length, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=input_ids.device), + input_ids.to(torch.int).argmax(dim=-1), + ] + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def _build_causal_attention_mask(self, bsz, seq_len, dtype): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype) + mask.fill_(torch.tensor(torch.finfo(dtype).min)) + mask.triu_(1) # zero out the lower diagonal + mask = mask.unsqueeze(1) # expand mask + return mask + + +class ContextCLIPTextEmbeddings(nn.Module): + def __init__(self, config: CLIPTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + + def forward( + self, + ctx_embeddings: torch.Tensor, + ctx_begin_pos: list, + input_ids: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + if ctx_embeddings is None: + ctx_len = 0 + else: + ctx_len = ctx_embeddings.shape[1] + + seq_length = (input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]) + ctx_len + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + # for each input embeddings, add the ctx embeddings at the correct position + input_embeds_ctx = [] + bsz = inputs_embeds.shape[0] + + if ctx_embeddings is not None: + for i in range(bsz): + cbp = ctx_begin_pos[i] + + prefix = inputs_embeds[i, :cbp] + # remove the special token embedding + suffix = inputs_embeds[i, cbp:] + + input_embeds_ctx.append(torch.cat([prefix, ctx_embeddings[i], suffix], dim=0)) + + inputs_embeds = torch.stack(input_embeds_ctx, dim=0) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..aa3dbdae966b35061ea70f3184d25a250c8e9b63 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py @@ -0,0 +1,355 @@ +# Copyright 2025 Salesforce.com, inc. +# Copyright 2025 The HuggingFace Team. All rights reserved.# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import PIL.Image +import torch +from transformers import CLIPTokenizer + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import PNDMScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput +from .blip_image_processing import BlipImageProcessor +from .modeling_blip2 import Blip2QFormerModel +from .modeling_ctx_clip import ContextCLIPTextModel + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers.pipelines import BlipDiffusionPipeline + >>> from diffusers.utils import load_image + >>> import torch + + >>> blip_diffusion_pipe = BlipDiffusionPipeline.from_pretrained( + ... "Salesforce/blipdiffusion", torch_dtype=torch.float16 + ... ).to("cuda") + + + >>> cond_subject = "dog" + >>> tgt_subject = "dog" + >>> text_prompt_input = "swimming underwater" + + >>> cond_image = load_image( + ... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/dog.jpg" + ... ) + >>> guidance_scale = 7.5 + >>> num_inference_steps = 25 + >>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate" + + + >>> output = blip_diffusion_pipe( + ... text_prompt_input, + ... cond_image, + ... cond_subject, + ... tgt_subject, + ... guidance_scale=guidance_scale, + ... num_inference_steps=num_inference_steps, + ... neg_prompt=negative_prompt, + ... height=512, + ... width=512, + ... ).images + >>> output[0].save("image.png") + ``` +""" + + +class BlipDiffusionPipeline(DeprecatedPipelineMixin, DiffusionPipeline): + """ + Pipeline for Zero-Shot Subject Driven Generation using Blip Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + tokenizer ([`CLIPTokenizer`]): + Tokenizer for the text encoder + text_encoder ([`ContextCLIPTextModel`]): + Text encoder to encode the text prompt + vae ([`AutoencoderKL`]): + VAE model to map the latents to the image + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + scheduler ([`PNDMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + qformer ([`Blip2QFormerModel`]): + QFormer model to get multi-modal embeddings from the text and image. + image_processor ([`BlipImageProcessor`]): + Image Processor to preprocess and postprocess the image. + ctx_begin_pos (int, `optional`, defaults to 2): + Position of the context token in the text encoder. + """ + + _last_supported_version = "0.33.1" + model_cpu_offload_seq = "qformer->text_encoder->unet->vae" + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: ContextCLIPTextModel, + vae: AutoencoderKL, + unet: UNet2DConditionModel, + scheduler: PNDMScheduler, + qformer: Blip2QFormerModel, + image_processor: BlipImageProcessor, + ctx_begin_pos: int = 2, + mean: list[float] = None, + std: list[float] = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + unet=unet, + scheduler=scheduler, + qformer=qformer, + image_processor=image_processor, + ) + self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std) + + def get_query_embeddings(self, input_image, src_subject): + return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False) + + # from the original Blip Diffusion code, specifies the target subject and augments the prompt by repeating it + def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20): + rv = [] + for prompt, tgt_subject in zip(prompts, tgt_subjects): + prompt = f"a {tgt_subject} {prompt.strip()}" + # a trick to amplify the prompt + rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps))) + + return rv + + # Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def encode_prompt(self, query_embeds, prompt, device=None): + device = device or self._execution_device + + # embeddings for prompt, with query_embeds as context + max_len = self.text_encoder.text_model.config.max_position_embeddings + max_len -= self.qformer.config.num_query_tokens + + tokenized_prompt = self.tokenizer( + prompt, + padding="max_length", + truncation=True, + max_length=max_len, + return_tensors="pt", + ).to(device) + + batch_size = query_embeds.shape[0] + ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size + + text_embeddings = self.text_encoder( + input_ids=tokenized_prompt.input_ids, + ctx_embeddings=query_embeds, + ctx_begin_pos=ctx_begin_pos, + )[0] + + return text_embeddings + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: list[str], + reference_image: PIL.Image.Image, + source_subject_category: list[str], + target_subject_category: list[str], + latents: torch.Tensor | None = None, + guidance_scale: float = 7.5, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + generator: torch.Generator | list[torch.Generator] | None = None, + neg_prompt: str | None = "", + prompt_strength: float = 1.0, + prompt_reps: int = 20, + output_type: str | None = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`list[str]`): + The prompt or prompts to guide the image generation. + reference_image (`PIL.Image.Image`): + The reference image to condition the generation on. + source_subject_category (`list[str]`): + The source subject category. + target_subject_category (`list[str]`): + The target subject category. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by random sampling. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + height (`int`, *optional*, defaults to 512): + The height of the generated image. + width (`int`, *optional*, defaults to 512): + The width of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + neg_prompt (`str`, *optional*, defaults to ""): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_strength (`float`, *optional*, defaults to 1.0): + The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps + to amplify the prompt. + prompt_reps (`int`, *optional*, defaults to 20): + The number of times the prompt is repeated along with prompt_strength to amplify the prompt. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + device = self._execution_device + + reference_image = self.image_processor.preprocess( + reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt" + )["pixel_values"] + reference_image = reference_image.to(device) + + if isinstance(prompt, str): + prompt = [prompt] + if isinstance(source_subject_category, str): + source_subject_category = [source_subject_category] + if isinstance(target_subject_category, str): + target_subject_category = [target_subject_category] + + batch_size = len(prompt) + + prompt = self._build_prompt( + prompts=prompt, + tgt_subjects=target_subject_category, + prompt_strength=prompt_strength, + prompt_reps=prompt_reps, + ) + query_embeds = self.get_query_embeddings(reference_image, source_subject_category) + text_embeddings = self.encode_prompt(query_embeds, prompt, device) + do_classifier_free_guidance = guidance_scale > 1.0 + if do_classifier_free_guidance: + max_length = self.text_encoder.text_model.config.max_position_embeddings + + uncond_input = self.tokenizer( + [neg_prompt] * batch_size, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + uncond_embeddings = self.text_encoder( + input_ids=uncond_input.input_ids.to(device), + ctx_embeddings=None, + )[0] + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1) + latents = self.prepare_latents( + batch_size=batch_size, + num_channels=self.unet.config.in_channels, + height=height // scale_down_factor, + width=width // scale_down_factor, + generator=generator, + latents=latents, + dtype=self.unet.dtype, + device=device, + ) + # set timesteps + extra_set_kwargs = {} + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + do_classifier_free_guidance = guidance_scale > 1.0 + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + noise_pred = self.unet( + latent_model_input, + timestep=t, + encoder_hidden_states=text_embeddings, + down_block_additional_residuals=None, + mid_block_additional_residual=None, + )["sample"] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step( + noise_pred, + t, + latents, + )["prev_sample"] + + if XLA_AVAILABLE: + xm.mark_step() + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/bria/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/bria/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..60e319ac7910e9c42320f2059b33fbd12b44b584 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/bria/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_bria"] = ["BriaPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_bria import BriaPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/bria/pipeline_bria.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/bria/pipeline_bria.py new file mode 100644 index 0000000000000000000000000000000000000000..95ae9ce96e7ee1224d558b5284f0c842fe696e72 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/bria/pipeline_bria.py @@ -0,0 +1,741 @@ +from typing import Any, Callable + +import numpy as np +import torch +from transformers import ( + CLIPImageProcessor, + CLIPVisionModelWithProjection, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin +from ...models import AutoencoderKL +from ...models.transformers.transformer_bria import BriaTransformer2DModel +from ...pipelines import DiffusionPipeline +from ...pipelines.bria.pipeline_output import BriaPipelineOutput +from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps +from ...schedulers import ( + DDIMScheduler, + EulerAncestralDiscreteScheduler, + FlowMatchEulerDiscreteScheduler, + KarrasDiffusionSchedulers, +) +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import BriaPipeline + + >>> pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + # BRIA's T5 text encoder is sensitive to precision. We need to cast it to bfloat16 and keep the final layer in float32. + + >>> pipe.text_encoder = pipe.text_encoder.to(dtype=torch.bfloat16) + >>> for block in pipe.text_encoder.encoder.block: + ... block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32) + # BRIA's VAE is not supported in mixed precision, so we use float32. + + >>> if pipe.vae.config.shift_factor == 0: + ... pipe.vae.to(dtype=torch.float32) + + >>> prompt = "Photorealistic food photography of a stack of fluffy pancakes on a white plate, with maple syrup being poured over them. On top of the pancakes are the words 'BRIA 3.2' in bold, yellow, 3D letters. The background is dark and out of focus." + >>> image = pipe(prompt).images[0] + >>> image.save("bria.png") + ``` +""" + + +def is_ng_none(negative_prompt): + return ( + negative_prompt is None + or negative_prompt == "" + or (isinstance(negative_prompt, list) and negative_prompt[0] is None) + or (type(negative_prompt) == list and negative_prompt[0] == "") + ) + + +def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000): + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + sigmas = timesteps / num_train_timesteps + + inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)] + new_sigmas = sigmas[inds] + return new_sigmas + + +class BriaPipeline(DiffusionPipeline): + r""" + Based on FluxPipeline with several changes: + - no pooled embeddings + - We use zero padding for prompts + - No guidance embedding since this is not a distilled version + + Args: + transformer ([`BriaTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. Bria uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + transformer: BriaTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler | KarrasDiffusionSchedulers, + vae: AutoencoderKL, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + ): + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.default_sample_size = 64 # due to patchify=> 128,128 => res of 1k,1k + + if self.vae.config.shift_factor is None: + self.vae.config.shift_factor = 0 + self.vae.to(dtype=torch.float32) + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 128, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + if not is_ng_none(negative_prompt): + negative_prompt = ( + batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + else: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device) + text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) + + return prompt_embeds, negative_prompt_embeds, text_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @attention_kwargs.setter + def attention_kwargs(self, value): + self._attention_kwargs = value + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 128, + device: torch.device | None = None, + ): + tokenizer = self.tokenizer + text_encoder = self.text_encoder + device = device or text_encoder.device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + prompt_embeds_list = [] + for p in prompt: + text_inputs = tokenizer( + p, + # padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + + # Concat zeros to max_sequence + b, seq_len, dim = prompt_embeds.shape + if seq_len < max_sequence_length: + padding = torch.zeros( + (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + prompt_embeds = torch.concat([prompt_embeds, padding], dim=1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=0) + prompt_embeds = prompt_embeds.to(device=device) + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, max_sequence_length, -1) + prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype) + return prompt_embeds + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // self.vae_scale_factor) + width = 2 * (int(width) // self.vae_scale_factor) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + + return latents + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1) + latent_image_ids = latent_image_ids.reshape( + batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 30, + timesteps: list[int] = None, + guidance_scale: float = 5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 128, + clip_value: None | float = None, + normalize: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass + `prompt_embeds`. instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best + results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best + results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` + argument in their `set_timesteps` method. If not defined, the default behavior when + `num_inference_steps` is passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 5.0): + <<<<<<< HEAD + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely + linked to the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + ======= + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of + equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is + enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images + that are closely linked to the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + >>>>>>> main + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch + generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for + image generation. Can be used to tweak the same generation with different prompts. If not + provided, a latents tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.bria.BriaPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined + under `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is + called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, + timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as + specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the + list will be passed as `callback_kwargs` argument. You will only be able to include variables + listed in the `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.bria.BriaPipelineOutput`] or `tuple`: [`~pipelines.bria.BriaPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list + with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self.attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None + + (prompt_embeds, negative_prompt_embeds, text_ids) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 # due to patch=2, we devide by 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + if ( + isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler) + and self.scheduler.config["use_dynamic_shifting"] + ): + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = latents.shape[1] + + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + mu=mu, + ) + else: + # 4. Prepare timesteps + # Sample from training sigmas + if isinstance(self.scheduler, DDIMScheduler) or isinstance( + self.scheduler, EulerAncestralDiscreteScheduler + ): + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, None, None + ) + else: + sigmas = get_original_sigmas( + num_train_timesteps=self.scheduler.config.num_train_timesteps, + num_inference_steps=num_inference_steps, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + if len(latent_image_ids.shape) == 3: + latent_image_ids = latent_image_ids[0] + if len(text_ids.shape) == 3: + text_ids = text_ids[0] + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + if type(self.scheduler) != FlowMatchEulerDiscreteScheduler: + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # This is predicts "v" from flow-matching or eps from diffusion + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + attention_kwargs=self.attention_kwargs, + return_dict=False, + txt_ids=text_ids, + img_ids=latent_image_ids, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + cfg_noise_pred_text = noise_pred_text.std() + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if normalize: + noise_pred = noise_pred * (0.7 * (cfg_noise_pred_text / noise_pred.std())) + 0.3 * noise_pred + + if clip_value: + assert clip_value > 0 + noise_pred = noise_pred.clip(-clip_value, clip_value) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents.to(dtype=torch.float32) / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return BriaPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/bria/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/bria/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..4bf2ed949c26b84e3e10d345caa28ee7ab190abb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/bria/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class BriaPipelineOutput(BaseOutput): + """ + Output class for Bria pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image] | np.ndarray diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/bria_fibo/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/bria_fibo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8dd77270902cfa436b4d5ccf130349c0c963b924 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/bria_fibo/__init__.py @@ -0,0 +1,51 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_bria_fibo"] = ["BriaFiboPipeline"] + _import_structure["pipeline_bria_fibo_edit"] = ["BriaFiboEditPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_bria_fibo import BriaFiboPipeline + from .pipeline_bria_fibo_edit import BriaFiboEditPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py new file mode 100644 index 0000000000000000000000000000000000000000..1f178066b17d9836cb5755f64ef35e2749ecc692 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py @@ -0,0 +1,838 @@ +# Copyright (c) Bria.ai. All rights reserved. +# +# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0). +# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/ +# +# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit, +# indicate if changes were made, and do not use the material for commercial purposes. +# +# See the license for further details. + +from typing import Any, Callable + +import numpy as np +import torch +from transformers import AutoTokenizer +from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM + +from ...image_processor import VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin +from ...models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan +from ...models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel +from ...pipelines.bria_fibo.pipeline_output import BriaFiboPipelineOutput +from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Example: + ```python + import torch + from diffusers import BriaFiboPipeline + from diffusers.modular_pipelines import ModularPipeline + + torch.set_grad_enabled(False) + vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True) + + pipe = BriaFiboPipeline.from_pretrained( + "briaai/FIBO", + trust_remote_code=True, + torch_dtype=torch.bfloat16, + ) + pipe.enable_model_cpu_offload() + + with torch.inference_mode(): + # 1. Create a prompt to generate an initial image + output = vlm_pipe(prompt="a beautiful dog") + json_prompt_generate = output.values["json_prompt"] + + # Generate the image from the structured json prompt + results_generate = pipe(prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=5) + results_generate.images[0].save("image_generate.png") + ``` +""" + + +class BriaFiboPipeline(DiffusionPipeline, FluxLoraLoaderMixin): + r""" + Args: + transformer (`BriaFiboTransformer2DModel`): + The transformer model for 2D diffusion modeling. + scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`): + Scheduler to be used with `transformer` to denoise the encoded latents. + vae (`AutoencoderKLWan`): + Variational Auto-Encoder for encoding and decoding images to and from latent representations. + text_encoder (`SmolLM3ForCausalLM`): + Text encoder for processing input prompts. + tokenizer (`AutoTokenizer`): + Tokenizer used for processing the input text prompts for the text_encoder. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + transformer: BriaFiboTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler | KarrasDiffusionSchedulers, + vae: AutoencoderKLWan, + text_encoder: SmolLM3ForCausalLM, + tokenizer: AutoTokenizer, + ): + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor = 16 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.default_sample_size = 64 + + def get_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + max_sequence_length: int = 2048, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + if not prompt: + raise ValueError("`prompt` must be a non-empty string or list of strings.") + + batch_size = len(prompt) + bot_token_id = 128000 + + text_encoder_device = device if device is not None else torch.device("cpu") + if not isinstance(text_encoder_device, torch.device): + text_encoder_device = torch.device(text_encoder_device) + + if all(p == "" for p in prompt): + input_ids = torch.full((batch_size, 1), bot_token_id, dtype=torch.long, device=text_encoder_device) + attention_mask = torch.ones_like(input_ids) + else: + tokenized = self.tokenizer( + prompt, + padding="longest", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + input_ids = tokenized.input_ids.to(text_encoder_device) + attention_mask = tokenized.attention_mask.to(text_encoder_device) + + if any(p == "" for p in prompt): + empty_rows = torch.tensor([p == "" for p in prompt], dtype=torch.bool, device=text_encoder_device) + input_ids[empty_rows] = bot_token_id + attention_mask[empty_rows] = 1 + + encoder_outputs = self.text_encoder( + input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_outputs.hidden_states + + prompt_embeds = torch.cat([hidden_states[-1], hidden_states[-2]], dim=-1) + prompt_embeds = prompt_embeds.to(device=device, dtype=dtype) + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + hidden_states = tuple( + layer.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) for layer in hidden_states + ) + attention_mask = attention_mask.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) + + return prompt_embeds, hidden_states, attention_mask + + @staticmethod + def pad_embedding(prompt_embeds, max_tokens, attention_mask=None): + # Pad embeddings to `max_tokens` while preserving the mask of real tokens. + batch_size, seq_len, dim = prompt_embeds.shape + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), dtype=prompt_embeds.dtype, device=prompt_embeds.device) + else: + attention_mask = attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype) + + if max_tokens < seq_len: + raise ValueError("`max_tokens` must be greater or equal to the current sequence length.") + + if max_tokens > seq_len: + pad_length = max_tokens - seq_len + padding = torch.zeros( + (batch_size, pad_length, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + prompt_embeds = torch.cat([prompt_embeds, padding], dim=1) + + mask_padding = torch.zeros( + (batch_size, pad_length), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + attention_mask = torch.cat([attention_mask, mask_padding], dim=1) + + return prompt_embeds, attention_mask + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + guidance_scale: float = 5, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 3000, + lora_scale: float | None = None, + ): + r""" + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + guidance_scale (`float`): + Guidance scale for classifier free guidance. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + prompt_attention_mask = None + negative_prompt_attention_mask = None + if prompt_embeds is None: + prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype) + prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers] + + if guidance_scale > 1: + if isinstance(negative_prompt, list) and negative_prompt[0] is None: + negative_prompt = "" + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds( + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.transformer.dtype) + negative_prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in negative_prompt_layers] + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + # Pad to longest + if prompt_attention_mask is not None: + prompt_attention_mask = prompt_attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype) + + if negative_prompt_embeds is not None: + if negative_prompt_attention_mask is not None: + negative_prompt_attention_mask = negative_prompt_attention_mask.to( + device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype + ) + max_tokens = max(negative_prompt_embeds.shape[1], prompt_embeds.shape[1]) + + prompt_embeds, prompt_attention_mask = self.pad_embedding( + prompt_embeds, max_tokens, attention_mask=prompt_attention_mask + ) + prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in prompt_layers] + + negative_prompt_embeds, negative_prompt_attention_mask = self.pad_embedding( + negative_prompt_embeds, max_tokens, attention_mask=negative_prompt_attention_mask + ) + negative_prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in negative_prompt_layers] + else: + max_tokens = prompt_embeds.shape[1] + prompt_embeds, prompt_attention_mask = self.pad_embedding( + prompt_embeds, max_tokens, attention_mask=prompt_attention_mask + ) + negative_prompt_layers = None + + dtype = self.text_encoder.dtype + text_ids = torch.zeros(prompt_embeds.shape[0], max_tokens, 3).to(device=device, dtype=dtype) + + return ( + prompt_embeds, + negative_prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_layers, + negative_prompt_layers, + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @staticmethod + # Based on diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _unpack_latents_no_patch(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels) + latents = latents.permute(0, 3, 1, 2) + + return latents + + @staticmethod + def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width): + latents = latents.permute(0, 2, 3, 1) + latents = latents.reshape(batch_size, height * width, num_channels_latents) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + do_patching=False, + ): + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if do_patching: + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + else: + latents = self._pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + return latents, latent_image_ids + + @staticmethod + def _prepare_attention_mask(attention_mask): + attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask) + + # convert to 0 - keep, -inf ignore + attention_matrix = torch.where( + attention_matrix == 1, 0.0, -torch.inf + ) # Apply -inf to ignored tokens for nulling softmax score + return attention_matrix + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 30, + timesteps: list[int] = None, + guidance_scale: float = 5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 3000, + do_patching=False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 3000): Maximum sequence length to use with the `prompt`. + do_patching (`bool`, *optional*, defaults to `False`): Whether to use patching. + Examples: + Returns: + [`~pipelines.flux.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.flux.BriaFiboPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_layers, + negative_prompt_layers, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + guidance_scale=guidance_scale, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + num_images_per_prompt=num_images_per_prompt, + lora_scale=lora_scale, + ) + prompt_batch_size = prompt_embeds.shape[0] + + if guidance_scale > 1: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_layers = [ + torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers)) + ] + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + total_num_layers_transformer = len(self.transformer.transformer_blocks) + len( + self.transformer.single_transformer_blocks + ) + if len(prompt_layers) >= total_num_layers_transformer: + # remove first layers + prompt_layers = prompt_layers[len(prompt_layers) - total_num_layers_transformer :] + else: + # duplicate last layer + prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers)) + + # 5. Prepare latent variables + + num_channels_latents = self.transformer.config.in_channels + if do_patching: + num_channels_latents = int(num_channels_latents / 4) + + latents, latent_image_ids = self.prepare_latents( + prompt_batch_size, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + do_patching, + ) + + latent_attention_mask = torch.ones( + [latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device + ) + if guidance_scale > 1: + latent_attention_mask = latent_attention_mask.repeat(2, 1) + + attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1) + attention_mask = self._prepare_attention_mask(attention_mask) # batch, seq => batch, seq, seq + attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting + + if self._joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + self._joint_attention_kwargs["attention_mask"] = attention_mask + + # Adapt scheduler to dynamic shifting (resolution dependent) + + if do_patching: + seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2)) + else: + seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor) + + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + + mu = calculate_shift( + seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + + # Init sigmas and timesteps according to shift size + # This changes the scheduler in-place according to the dynamic scheduling + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps=num_inference_steps, + device=device, + timesteps=None, + sigmas=sigmas, + mu=mu, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # Support old different diffusers versions + if len(latent_image_ids.shape) == 3: + latent_image_ids = latent_image_ids[0] + + if len(text_ids.shape) == 3: + text_ids = text_ids[0] + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to( + device=latent_model_input.device, dtype=latent_model_input.dtype + ) + + # This is predicts "v" from flow-matching or eps from diffusion + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + text_encoder_layers=prompt_layers, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + txt_ids=text_ids, + img_ids=latent_image_ids, + )[0] + + # perform guidance + if guidance_scale > 1: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + if do_patching: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + else: + latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor) + + latents = latents.unsqueeze(dim=2) + latents_device = latents[0].device + latents_dtype = latents[0].dtype + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents_device, latents_dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents_device, latents_dtype + ) + latents_scaled = [latent / latents_std + latents_mean for latent in latents] + latents_scaled = torch.cat(latents_scaled, dim=0) + image = [] + for scaled_latent in latents_scaled: + curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0] + curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type) + image.append(curr_image) + if len(image) == 1: + image = image[0] + else: + image = np.stack(image, axis=0) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return BriaFiboPipelineOutput(images=image) + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if max_sequence_length is not None and max_sequence_length > 3000: + raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py new file mode 100644 index 0000000000000000000000000000000000000000..c2327bbce1c7820312652f3476dcda2f4acd9e0c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py @@ -0,0 +1,1133 @@ +# Copyright (c) Bria.ai. All rights reserved. +# +# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0). +# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/ +# +# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit, +# indicate if changes were made, and do not use the material for commercial purposes. +# +# See the license for further details. + +import json +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer +from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin +from ...models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan +from ...models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel +from ...pipelines.bria_fibo.pipeline_output import BriaFiboPipelineOutput +from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +PipelineMaskInput = Union[ + torch.FloatTensor, Image.Image, List[Image.Image], List[torch.FloatTensor], np.ndarray, List[np.ndarray] +] + +# TODO: Update example docstring +EXAMPLE_DOC_STRING = """ + Example: + ```python + import torch + from diffusers import BriaFiboEditPipeline + from diffusers.modular_pipelines import ModularPipeline + + torch.set_grad_enabled(False) + vlm_pipe = ModularPipelineBlocks.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True) + vlm_pipe = vlm_pipe.init_pipeline() + + pipe = BriaFiboEditPipeline.from_pretrained( + "briaai/fibo-edit", + torch_dtype=torch.bfloat16, + ) + pipe.to("cuda") + + output = vlm_pipe( + prompt="A hyper-detailed, ultra-fluffy owl sitting in the trees at night, looking directly at the camera with wide, adorable, expressive eyes. Its feathers are soft and voluminous, catching the cool moonlight with subtle silver highlights. The owl's gaze is curious and full of charm, giving it a whimsical, storybook-like personality." + ) + json_prompt_generate = json.loads(output.values["json_prompt"]) + + image = Image.open("image_generate.png") + + edit_prompt = "Make the owl to be a cat" + + json_prompt_generate["edit_instruction"] = edit_prompt + + results_generate = pipe( + prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=3.5, image=image, output_type="np" + ) + ``` +""" + +PREFERRED_RESOLUTION = { + 256 * 256: [(208, 304), (224, 288), (256, 256), (288, 224), (304, 208), (320, 192), (336, 192)], + 512 * 512: [ + (416, 624), + (432, 592), + (464, 560), + (512, 512), + (544, 480), + (576, 448), + (592, 432), + (608, 416), + (624, 416), + (640, 400), + (672, 384), + (704, 368), + ], + 1024 * 1024: [ + (832, 1248), + (880, 1184), + (912, 1136), + (1024, 1024), + (1136, 912), + (1184, 880), + (1216, 848), + (1248, 832), + (1248, 832), + (1264, 816), + (1296, 800), + (1360, 768), + ], +} + + +def is_valid_edit_json(json_input: str | dict): + """ + Check if the input is a valid JSON string or dict with an "edit_instruction" key. + + Args: + json_input (`str` or `dict`): + The JSON string or dict to check. + + Returns: + `bool`: True if the input is a valid JSON string or dict with an "edit_instruction" key, False otherwise. + """ + try: + if isinstance(json_input, str) and "edit_instruction" in json_input: + json.loads(json_input) + return True + elif isinstance(json_input, dict) and "edit_instruction" in json_input: + return True + else: + return False + except json.JSONDecodeError: + return False + + +def is_valid_mask(mask: PipelineMaskInput): + """ + Check if the mask is a valid mask. + """ + if isinstance(mask, torch.Tensor): + return True + elif isinstance(mask, Image.Image): + return True + elif isinstance(mask, list): + return all(isinstance(m, (torch.Tensor, Image.Image, np.ndarray)) for m in mask) + elif isinstance(mask, np.ndarray): + return mask.ndim in [2, 3] and mask.min() >= 0 and mask.max() <= 1 + else: + return False + + +def get_mask_size(mask: PipelineMaskInput): + """ + Get the size of the mask. + """ + if isinstance(mask, torch.Tensor): + return mask.shape[-2:] + elif isinstance(mask, Image.Image): + return mask.size[::-1] # (height, width) + elif isinstance(mask, list): + return [get_mask_size(m) for m in mask] + elif isinstance(mask, np.ndarray): + return mask.shape[-2:] + else: + return None + + +def get_image_size(image: PipelineImageInput): + """ + Get the size of the image. + """ + if isinstance(image, torch.Tensor): + return image.shape[-2:] + elif isinstance(image, Image.Image): + return image.size[::-1] # (height, width) + elif isinstance(image, list): + return [get_image_size(i) for i in image] + else: + return None + + +def paste_mask_on_image(mask: PipelineMaskInput, image: PipelineImageInput): + """convert mask and image to PIL Images and paste the mask on the image""" + if isinstance(mask, torch.Tensor): + if mask.ndim == 3 and mask.shape[0] == 1: + mask = mask.squeeze(0) + mask = Image.fromarray((mask.cpu().numpy() * 255).astype(np.uint8)) + elif isinstance(mask, Image.Image): + pass + elif isinstance(mask, list): + mask = mask[0] + if isinstance(mask, torch.Tensor): + if mask.ndim == 3 and mask.shape[0] == 1: + mask = mask.squeeze(0) + mask = Image.fromarray((mask.cpu().numpy() * 255).astype(np.uint8)) + elif isinstance(mask, np.ndarray): + mask = Image.fromarray((mask * 255).astype(np.uint8)) + elif isinstance(mask, np.ndarray): + mask = Image.fromarray((mask * 255).astype(np.uint8)) + + if isinstance(image, torch.Tensor): + if image.ndim == 3: + image = image.permute(1, 2, 0) + image = Image.fromarray((image.cpu().numpy() * 255).astype(np.uint8)) + elif isinstance(image, Image.Image): + pass + elif isinstance(image, list): + image = image[0] + if isinstance(image, torch.Tensor): + if image.ndim == 3: + image = image.permute(1, 2, 0) + image = Image.fromarray((image.cpu().numpy() * 255).astype(np.uint8)) + elif isinstance(image, np.ndarray): + image = Image.fromarray((image * 255).astype(np.uint8)) + elif isinstance(image, np.ndarray): + image = Image.fromarray((image * 255).astype(np.uint8)) + + mask = mask.convert("L") + image = image.convert("RGB") + gray_color = (128, 128, 128) + gray_img = Image.new("RGB", image.size, gray_color) + image = Image.composite(gray_img, image, mask) + return image + + +class BriaFiboEditPipeline(DiffusionPipeline, FluxLoraLoaderMixin): + r""" + Args: + transformer (`BriaFiboTransformer2DModel`): + The transformer model for 2D diffusion modeling. + scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`): + Scheduler to be used with `transformer` to denoise the encoded latents. + vae (`AutoencoderKLWan`): + Variational Auto-Encoder for encoding and decoding images to and from latent representations. + text_encoder (`SmolLM3ForCausalLM`): + Text encoder for processing input prompts. + tokenizer (`AutoTokenizer`): + Tokenizer used for processing the input text prompts for the text_encoder. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + transformer: BriaFiboTransformer2DModel, + scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers], + vae: AutoencoderKLWan, + text_encoder: SmolLM3ForCausalLM, + tokenizer: AutoTokenizer, + ): + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor = 16 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # * 2) + self.default_sample_size = 32 # 64 + + def get_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + max_sequence_length: int = 2048, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + if not prompt: + raise ValueError("`prompt` must be a non-empty string or list of strings.") + + batch_size = len(prompt) + bot_token_id = 128000 + + text_encoder_device = device if device is not None else torch.device("cpu") + if not isinstance(text_encoder_device, torch.device): + text_encoder_device = torch.device(text_encoder_device) + + if all(p == "" for p in prompt): + input_ids = torch.full((batch_size, 1), bot_token_id, dtype=torch.long, device=text_encoder_device) + attention_mask = torch.ones_like(input_ids) + else: + tokenized = self.tokenizer( + prompt, + padding="longest", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + input_ids = tokenized.input_ids.to(text_encoder_device) + attention_mask = tokenized.attention_mask.to(text_encoder_device) + + if any(p == "" for p in prompt): + empty_rows = torch.tensor([p == "" for p in prompt], dtype=torch.bool, device=text_encoder_device) + input_ids[empty_rows] = bot_token_id + attention_mask[empty_rows] = 1 + + encoder_outputs = self.text_encoder( + input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_outputs.hidden_states + + prompt_embeds = torch.cat([hidden_states[-1], hidden_states[-2]], dim=-1) + prompt_embeds = prompt_embeds.to(device=device, dtype=dtype) + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + hidden_states = tuple( + layer.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) for layer in hidden_states + ) + attention_mask = attention_mask.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) + + return prompt_embeds, hidden_states, attention_mask + + @staticmethod + def pad_embedding(prompt_embeds, max_tokens, attention_mask=None): + # Pad embeddings to `max_tokens` while preserving the mask of real tokens. + batch_size, seq_len, dim = prompt_embeds.shape + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), dtype=prompt_embeds.dtype, device=prompt_embeds.device) + else: + attention_mask = attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype) + + if max_tokens < seq_len: + raise ValueError("`max_tokens` must be greater or equal to the current sequence length.") + + if max_tokens > seq_len: + pad_length = max_tokens - seq_len + padding = torch.zeros( + (batch_size, pad_length, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + prompt_embeds = torch.cat([prompt_embeds, padding], dim=1) + + mask_padding = torch.zeros( + (batch_size, pad_length), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + attention_mask = torch.cat([attention_mask, mask_padding], dim=1) + + return prompt_embeds, attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + guidance_scale: float = 5, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 3000, + lora_scale: bool | None = None, + ): + r""" + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + guidance_scale (`float`): + Guidance scale for classifier free guidance. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + prompt_attention_mask = None + negative_prompt_attention_mask = None + if prompt_embeds is None: + prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype) + prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers] + + if guidance_scale > 1: + if isinstance(negative_prompt, list) and negative_prompt[0] is None: + negative_prompt = "" + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds( + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.transformer.dtype) + negative_prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in negative_prompt_layers] + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + # Pad to longest + if prompt_attention_mask is not None: + prompt_attention_mask = prompt_attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype) + + if negative_prompt_embeds is not None: + if negative_prompt_attention_mask is not None: + negative_prompt_attention_mask = negative_prompt_attention_mask.to( + device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype + ) + max_tokens = max(negative_prompt_embeds.shape[1], prompt_embeds.shape[1]) + + prompt_embeds, prompt_attention_mask = self.pad_embedding( + prompt_embeds, max_tokens, attention_mask=prompt_attention_mask + ) + prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in prompt_layers] + + negative_prompt_embeds, negative_prompt_attention_mask = self.pad_embedding( + negative_prompt_embeds, max_tokens, attention_mask=negative_prompt_attention_mask + ) + negative_prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in negative_prompt_layers] + else: + max_tokens = prompt_embeds.shape[1] + prompt_embeds, prompt_attention_mask = self.pad_embedding( + prompt_embeds, max_tokens, attention_mask=prompt_attention_mask + ) + negative_prompt_layers = None + + dtype = self.text_encoder.dtype + text_ids = torch.zeros(prompt_embeds.shape[0], max_tokens, 3).to(device=device, dtype=dtype) + + return ( + prompt_embeds, + negative_prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_layers, + negative_prompt_layers, + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @staticmethod + # Based on diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _unpack_latents_no_patch(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels) + latents = latents.permute(0, 3, 1, 2) + + return latents + + @staticmethod + def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width): + latents = latents.permute(0, 2, 3, 1) + latents = latents.reshape(batch_size, height * width, num_channels_latents) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + do_patching=False, + ): + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if do_patching: + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + else: + latents = self._pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + return latents, latent_image_ids + + @staticmethod + def _prepare_attention_mask(attention_mask): + attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask) + + # convert to 0 - keep, -inf ignore + attention_matrix = torch.where( + attention_matrix == 1, 0.0, -torch.inf + ) # Apply -inf to ignored tokens for nulling softmax score + return attention_matrix + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Optional[PipelineImageInput] = None, + mask: Optional[PipelineMaskInput] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 30, + timesteps: List[int] = None, + seed: int | None = None, + guidance_scale: float = 5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 3000, + do_patching=False, + _auto_resize: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image` or `torch.FloatTensor`, *optional*): + The image to guide the image generation. If not defined, the pipeline will generate an image from + scratch. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + seed (`int`, *optional*): + A seed used to make generation deterministic. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 3000): Maximum sequence length to use with the `prompt`. + do_patching (`bool`, *optional*, defaults to `False`): Whether to use patching. + Examples: + Returns: + [`~pipelines.flux.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.flux.BriaFiboPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + if height is None or width is None: + if image is not None: + image_height, image_width = self.image_processor.get_default_height_width(image) + if _auto_resize: + image_width, image_height = min( + PREFERRED_RESOLUTION[1024 * 1024], + key=lambda size: abs(size[0] / size[1] - image_width / image_height), + ) + width, height = image_width, image_height + else: + raise ValueError("You must provide either an image or both height and width.") + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + seed=seed, + image=image, + mask=mask, + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + if mask is not None and image is not None: + image = paste_mask_on_image(mask, image) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + + if prompt is not None and is_valid_edit_json(prompt): + prompt = json.dumps(prompt) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + if generator is None and seed is not None: + generator = torch.Generator(device=device).manual_seed(seed) + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_layers, + negative_prompt_layers, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + guidance_scale=guidance_scale, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + num_images_per_prompt=num_images_per_prompt, + lora_scale=lora_scale, + ) + prompt_batch_size = prompt_embeds.shape[0] + + if guidance_scale > 1: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_layers = [ + torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers)) + ] + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + total_num_layers_transformer = len(self.transformer.transformer_blocks) + len( + self.transformer.single_transformer_blocks + ) + if len(prompt_layers) >= total_num_layers_transformer: + # remove first layers + prompt_layers = prompt_layers[len(prompt_layers) - total_num_layers_transformer :] + else: + # duplicate last layer + prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers)) + + # Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + image = self.image_processor.resize(image, height, width) + image = self.image_processor.preprocess(image, height, width) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + if do_patching: + num_channels_latents = int(num_channels_latents / 4) + + latents, latent_image_ids = self.prepare_latents( + prompt_batch_size, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + do_patching, + ) + + if image is not None: + image_latents, image_ids = self.prepare_image_latents( + image=image, + batch_size=batch_size * num_images_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + ) + latent_image_ids = torch.cat([latent_image_ids, image_ids], dim=0) # dim 0 is sequence dimension + else: + image_latents = None + + latent_attention_mask = torch.ones( + [latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device + ) + if guidance_scale > 1: + latent_attention_mask = latent_attention_mask.repeat(2, 1) + + if image_latents is None: + attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1) + else: + image_latent_attention_mask = torch.ones( + [image_latents.shape[0], image_latents.shape[1]], + dtype=image_latents.dtype, + device=image_latents.device, + ) + if guidance_scale > 1: + image_latent_attention_mask = image_latent_attention_mask.repeat(2, 1) + attention_mask = torch.cat( + [prompt_attention_mask, latent_attention_mask, image_latent_attention_mask], dim=1 + ) + + attention_mask = self.create_attention_matrix(attention_mask) # batch, seq => batch, seq, seq + attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting + + if self._joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + self._joint_attention_kwargs["attention_mask"] = attention_mask + + # Adapt scheduler to dynamic shifting (resolution dependent) + + if do_patching: + seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2)) + else: + seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor) + + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + + mu = calculate_shift( + seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + + # Init sigmas and timesteps according to shift size + # This changes the scheduler in-place according to the dynamic scheduling + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps=num_inference_steps, + device=device, + timesteps=None, + sigmas=sigmas, + mu=mu, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # Support old different diffusers versions + if len(latent_image_ids.shape) == 3: + latent_image_ids = latent_image_ids[0] + + if len(text_ids.shape) == 3: + text_ids = text_ids[0] + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = latents + + if image_latents is not None: + latent_model_input = torch.cat([latent_model_input, image_latents], dim=1) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latent_model_input] * 2) if guidance_scale > 1 else latent_model_input + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to( + device=latent_model_input.device, dtype=latent_model_input.dtype + ) + + # This is predicts "v" from flow-matching or eps from diffusion + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + text_encoder_layers=prompt_layers, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + txt_ids=text_ids, + img_ids=latent_image_ids, + )[0] + + # perform guidance + if guidance_scale > 1: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred[:, : latents.shape[1], ...], t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + if do_patching: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + else: + latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor) + + latents = latents.unsqueeze(dim=2) + latents_device = latents[0].device + latents_dtype = latents[0].dtype + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents_device, latents_dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents_device, latents_dtype + ) + latents_scaled = [latent / latents_std + latents_mean for latent in latents] + latents_scaled = torch.cat(latents_scaled, dim=0) + image = [] + for scaled_latent in latents_scaled: + curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0] + curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type) + image.append(curr_image) + if len(image) == 1: + image = image[0] + else: + image = np.stack(image, axis=0) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return BriaFiboPipelineOutput(images=image) + + def prepare_image_latents( + self, + image: torch.Tensor, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator | list[torch.Generator] | None = None, + ): + image = image.to(device=device, dtype=dtype) + + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + # scaling + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, dtype + ) + + image_latents_cthw = self.vae.encode(image.unsqueeze(2)).latent_dist.mean + latents_scaled = [(latent - latents_mean) * latents_std for latent in image_latents_cthw] + image_latents_cthw = torch.concat(latents_scaled, dim=0) + image_latents_bchw = image_latents_cthw[:, :, 0, :, :] + + image_latent_height, image_latent_width = image_latents_bchw.shape[2:] + image_latents_bsd = self._pack_latents_no_patch( + latents=image_latents_bchw, + batch_size=batch_size, + num_channels_latents=num_channels_latents, + height=image_latent_height, + width=image_latent_width, + ) + # breakpoint() + image_ids = self._prepare_latent_image_ids( + batch_size=batch_size, height=image_latent_height, width=image_latent_width, device=device, dtype=dtype + ) + # image ids are the same as latent ids with the first dimension set to 1 instead of 0 + image_ids[..., 0] = 1 + return image_latents_bsd, image_ids + + def check_inputs( + self, + prompt, + seed, + image, + mask, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if seed is not None and not isinstance(seed, int): + raise ValueError("Seed must be an integer") + if image is not None and not isinstance(image, (torch.Tensor, Image.Image, list)): + raise ValueError("Image must be a valid image") + if image is None and mask is not None: + raise ValueError("If mask is provided, image must also be provided") + + if mask is not None and not is_valid_mask(mask): + raise ValueError("Mask must be a valid mask") + + if mask is not None and image is not None and not (get_mask_size(mask) == get_image_size(image)): + raise ValueError("Mask and image must have the same size") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and not is_valid_edit_json(prompt): + raise ValueError(f"`prompt` has to be a valid JSON string or dict but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if max_sequence_length is not None and max_sequence_length > 3000: + raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}") + + def create_attention_matrix(self, attention_mask): + attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask) + + # convert to 0 - keep, -inf ignore + attention_matrix = torch.where( + attention_matrix == 1, 0.0, -torch.inf + ) # Apply -inf to ignored tokens for nulling softmax score + return attention_matrix diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/bria_fibo/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/bria_fibo/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..0c131db29d9f08d6767a4c11ac78055f020916ba --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/bria_fibo/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class BriaFiboPipelineOutput(BaseOutput): + """ + Output class for BriaFibo pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image, np.ndarray] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chroma/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chroma/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..25069b5543c1d727e7f71630fbbfa389efdb5413 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chroma/__init__.py @@ -0,0 +1,51 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["ChromaPipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_chroma"] = ["ChromaPipeline"] + _import_structure["pipeline_chroma_img2img"] = ["ChromaImg2ImgPipeline"] + _import_structure["pipeline_chroma_inpainting"] = ["ChromaInpaintPipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_chroma import ChromaPipeline + from .pipeline_chroma_img2img import ChromaImg2ImgPipeline + from .pipeline_chroma_inpainting import ChromaInpaintPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chroma/pipeline_chroma.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chroma/pipeline_chroma.py new file mode 100644 index 0000000000000000000000000000000000000000..bc782107022de1b75fcd5b7d1f8dc4a87aea2667 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chroma/pipeline_chroma.py @@ -0,0 +1,976 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ChromaTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import ChromaPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ChromaPipeline + + >>> model_id = "lodestones/Chroma1-HD" + >>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors" + >>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16) + >>> pipe = ChromaPipeline.from_pretrained( + ... model_id, + ... transformer=transformer, + ... torch_dtype=torch.bfloat16, + ... ) + >>> pipe.enable_model_cpu_offload() + >>> prompt = [ + ... "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done." + ... ] + >>> negative_prompt = [ + ... "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors" + ... ] + >>> image = pipe(prompt, negative_prompt=negative_prompt).images[0] + >>> image.save("chroma.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ChromaPipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, + FluxIPAdapterMixin, +): + r""" + The Chroma pipeline for text-to-image generation. + + Reference: https://huggingface.co/lodestones/Chroma1-HD/ + + Args: + transformer ([`ChromaTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representation + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: ChromaTransformer2DModel, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.default_sample_size = 128 + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + tokenizer_mask = text_inputs.attention_mask + + tokenizer_mask_device = tokenizer_mask.to(device) + + # unlike FLUX, Chroma uses the attention mask when generating the T5 embedding + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + output_hidden_states=False, + attention_mask=tokenizer_mask_device, + )[0] + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # for the text tokens, chroma requires that all except the first padding token are masked out during the forward pass through the transformer + seq_lengths = tokenizer_mask_device.sum(dim=1) + mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1) + attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + attention_mask = attention_mask.repeat(1, num_images_per_prompt) + attention_mask = attention_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, attention_mask + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + do_classifier_free_guidance: bool = True, + max_sequence_length: int = 512, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + negative_text_ids = None + + if do_classifier_free_guidance: + if negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = ( + batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return ( + prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_embeds, + negative_text_ids, + negative_prompt_attention_mask, + ) + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + return image_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_ip_adapter_image in ip_adapter_image: + single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) + image_embeds.append(single_image_embeds[None, :]) + else: + if not isinstance(ip_adapter_image_embeds, list): + ip_adapter_image_embeds = [ip_adapter_image_embeds] + + if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_image_embeds in ip_adapter_image_embeds: + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for single_image_embeds in image_embeds: + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_embeds=None, + negative_prompt_attention_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError( + "Cannot provide `negative_prompt_embeds` without also providing `negative_prompt_attention_mask" + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + def _prepare_attention_mask( + self, + batch_size, + sequence_length, + dtype, + attention_mask=None, + ): + if attention_mask is None: + return attention_mask + + # Extend the prompt attention mask to account for image tokens in the final sequence + attention_mask = torch.cat( + [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)], + dim=1, + ) + + return attention_mask + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 35, + sigmas: list[float] | None = None, + guidance_scale: float = 5.0, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + negative_ip_adapter_image: PipelineImageInput | None = None, + negative_ip_adapter_image_embeds: list[torch.Tensor] | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + not greater than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_ip_adapter_image: + (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + negative_ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (torch.Tensor, *optional*): + Attention mask for the prompt embeddings. Used to mask out padding tokens in the prompt sequence. + Chroma requires a single padding token remain unmasked. Please refer to + https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training + negative_prompt_attention_mask (torch.Tensor, *optional*): + Attention mask for the negative prompt embeddings. Used to mask out padding tokens in the negative + prompt sequence. Chroma requires a single padding token remain unmasked. PLease refer to + https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.ChromaPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.chroma.ChromaPipelineOutput`] or `tuple`: [`~pipelines.chroma.ChromaPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_attention_mask=negative_prompt_attention_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_embeds, + negative_text_ids, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + do_classifier_free_guidance=self.do_classifier_free_guidance, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + + attention_mask = self._prepare_attention_mask( + batch_size=latents.shape[0], + sequence_length=image_seq_len, + dtype=latents.dtype, + attention_mask=prompt_attention_mask, + ) + negative_attention_mask = self._prepare_attention_mask( + batch_size=latents.shape[0], + sequence_length=image_seq_len, + dtype=latents.dtype, + attention_mask=negative_prompt_attention_mask, + ) + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( + negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None + ): + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( + negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None + ): + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: + negative_image_embeds = self.prepare_ip_adapter_image_embeds( + negative_ip_adapter_image, + negative_ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + attention_mask=attention_mask, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + attention_mask=negative_attention_mask, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ChromaPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chroma/pipeline_chroma_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chroma/pipeline_chroma_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..e1f6e2f8d8afea3b9ce8bdaf9d18e662c7d9ca81 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chroma/pipeline_chroma_img2img.py @@ -0,0 +1,1060 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ChromaTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import ChromaPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ChromaTransformer2DModel, ChromaImg2ImgPipeline + + >>> model_id = "lodestones/Chroma1-HD" + >>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors" + >>> pipe = ChromaImg2ImgPipeline.from_pretrained( + ... model_id, + ... transformer=transformer, + ... torch_dtype=torch.bfloat16, + ... ) + >>> pipe.enable_model_cpu_offload() + >>> init_image = load_image( + ... "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + ... ) + >>> prompt = "a scenic fastasy landscape with a river and mountains in the background, vibrant colors, detailed, high resolution" + >>> negative_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors" + >>> image = pipe(prompt, image=init_image, negative_prompt=negative_prompt).images[0] + >>> image.save("chroma-img2img.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ChromaImg2ImgPipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, + FluxIPAdapterMixin, +): + r""" + The Chroma pipeline for image-to-image generation. + + Reference: https://huggingface.co/lodestones/Chroma1-HD/ + + Args: + transformer ([`ChromaTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representation + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: ChromaTransformer2DModel, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.default_sample_size = 128 + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + tokenizer_mask = text_inputs.attention_mask + + tokenizer_mask_device = tokenizer_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + output_hidden_states=False, + attention_mask=tokenizer_mask_device, + )[0] + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + seq_lengths = tokenizer_mask_device.sum(dim=1) + mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1) + attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + attention_mask = attention_mask.repeat(1, num_images_per_prompt) + attention_mask = attention_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, attention_mask + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + do_classifier_free_guidance: bool = True, + max_sequence_length: int = 512, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + negative_text_ids = None + + if do_classifier_free_guidance: + if negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = ( + batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return ( + prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_embeds, + negative_text_ids, + negative_prompt_attention_mask, + ) + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + return image_embeds + + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + device = device or self._execution_device + + image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_ip_adapter_image in ip_adapter_image: + single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) + image_embeds.append(single_image_embeds[None, :]) + else: + if not isinstance(ip_adapter_image_embeds, list): + ip_adapter_image_embeds = [ip_adapter_image_embeds] + + if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_image_embeds in ip_adapter_image_embeds: + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for single_image_embeds in image_embeds: + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + def check_inputs( + self, + prompt, + height, + width, + strength, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError( + "Cannot provide `negative_prompt_embeds` without also providing `negative_prompt_attention_mask" + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + def _prepare_latent_image_ids(height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(height // 2, width // 2, device, dtype) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_image_ids + + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, latent_image_ids + + def _prepare_attention_mask( + self, + batch_size, + sequence_length, + dtype, + attention_mask=None, + ): + if attention_mask is None: + return attention_mask + + # Extend the prompt attention mask to account for image tokens in the final sequence + attention_mask = torch.cat( + [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)], + dim=1, + ) + attention_mask = attention_mask.to(dtype) + + return attention_mask + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 35, + sigmas: list[float] | None = None, + guidance_scale: float = 5.0, + strength: float = 0.9, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + negative_ip_adapter_image: PipelineImageInput | None = None, + negative_ip_adapter_image_embeds: list[torch.Tensor] | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + not greater than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 35): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + strength (`float, *optional*, defaults to 0.9): + Conceptually, indicates how much to transform the reference image. Must be between 0 and 1. image will + be used as a starting point, adding more noise to it the larger the strength. The number of denoising + steps depends on the amount of noise initially added. When strength is 1, added noise will be maximum + and the denoising process will run for the full number of iterations specified in num_inference_steps. + A value of 1, therefore, essentially ignores image. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_ip_adapter_image: + (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + negative_ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (torch.Tensor, *optional*): + Attention mask for the prompt embeddings. Used to mask out padding tokens in the prompt sequence. + Chroma requires a single padding token remain unmasked. Please refer to + https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training + negative_prompt_attention_mask (torch.Tensor, *optional*): + Attention mask for the negative prompt embeddings. Used to mask out padding tokens in the negative + prompt sequence. Chroma requires a single padding token remain unmasked. PLease refer to + https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.ChromaPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.chroma.ChromaPipelineOutput`] or `tuple`: [`~pipelines.chroma.ChromaPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + strength, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Preprocess image + init_image = self.image_processor.preprocess(image, height=height, width=width) + init_image = init_image.to(dtype=torch.float32) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_embeds, + negative_text_ids, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + do_classifier_free_guidance=self.do_classifier_free_guidance, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + attention_mask = self._prepare_attention_mask( + batch_size=latents.shape[0], + sequence_length=image_seq_len, + dtype=latents.dtype, + attention_mask=prompt_attention_mask, + ) + negative_attention_mask = self._prepare_attention_mask( + batch_size=latents.shape[0], + sequence_length=image_seq_len, + dtype=latents.dtype, + attention_mask=negative_prompt_attention_mask, + ) + + # 6. Prepare image embeddings + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( + negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None + ): + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( + negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None + ): + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: + negative_image_embeds = self.prepare_ip_adapter_image_embeds( + negative_ip_adapter_image, + negative_ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + attention_mask=attention_mask, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + + noise_pred_uncond = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + attention_mask=negative_attention_mask, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ChromaPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py new file mode 100644 index 0000000000000000000000000000000000000000..52c2f7e51cf23970988230b323b3485bd2a3ed56 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py @@ -0,0 +1,1184 @@ +""" +ChromaInpaintPipeline implements a text-guided image inpainting pipeline for the lodestones/Chroma1-HD model, based on +the ChromaPipeline from Hugging Face Diffusers:contentReference[oaicite:0]{index=0} and the Stable Diffusion inpainting +approach:contentReference[oaicite:1]{index=1}. +""" + +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPVisionModelWithProjection, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import ChromaTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..chroma.pipeline_output import ChromaPipelineOutput +from ..pipeline_utils import DiffusionPipeline + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ChromaInpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = ChromaInpaintPipeline.from_pretrained("lodestones/Chroma1-HD", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + >>> source = load_image(img_url) + >>> mask = load_image(mask_url) + >>> image = pipe(prompt=prompt, image=source, mask_image=mask).images[0] + >>> image.save("chroma_inpainting.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ChromaInpaintPipeline( + DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin, FluxIPAdapterMixin +): + r""" + The Flux pipeline for image inpainting. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`ChromaTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: ChromaTransformer2DModel, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.default_sample_size = 128 + + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + vae_latent_channels=self.latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str], None] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + tokenizer_mask = text_inputs.attention_mask + + tokenizer_mask = tokenizer_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + output_hidden_states=False, + attention_mask=tokenizer_mask, + )[0] + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + seq_lengths = tokenizer_mask.sum(dim=1) + mask_indices = torch.arange(tokenizer_mask.size(1), device=device).unsqueeze(0).expand(batch_size, -1) + attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + attention_mask = attention_mask.repeat(1, num_images_per_prompt) + attention_mask = attention_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Union[str, List[str], None] = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + do_classifier_free_guidance: bool = True, + max_sequence_length: int = 256, + lora_scale: bool | None = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3, device=device, dtype=dtype) + negative_text_ids = None + + if do_classifier_free_guidance: + if negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = ( + batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3, device=device, dtype=dtype) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return ( + prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_embeds, + negative_text_ids, + negative_prompt_attention_mask, + ) + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + return image_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_ip_adapter_image in ip_adapter_image: + single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) + image_embeds.append(single_image_embeds[None, :]) + else: + if not isinstance(ip_adapter_image_embeds, list): + ip_adapter_image_embeds = [ip_adapter_image_embeds] + + if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_image_embeds in ip_adapter_image_embeds: + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for single_image_embeds in image_embeds: + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + image, + mask_image, + strength, + height, + width, + output_type, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + max_sequence_length=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError( + "Cannot provide `negative_prompt_embeds` without also providing `negative_prompt_attention_mask" + ) + + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + else: + noise = latents.to(device) + latents = noise + + noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, noise, image_latents, latent_image_ids + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 16: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = ( + masked_image_latents - self.vae.config.shift_factor + ) * self.vae.config.scaling_factor + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + mask = self._pack_latents( + mask.repeat(1, num_channels_latents, 1, 1), + batch_size, + num_channels_latents, + height, + width, + ) + + return mask, masked_image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + def _prepare_attention_mask( + self, + batch_size, + sequence_length, + dtype, + attention_mask=None, + ): + if attention_mask is None: + return attention_mask + + # Extend the prompt attention mask to account for image tokens in the final sequence + attention_mask = torch.cat( + [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)], + dim=1, + ) + attention_mask = attention_mask.to(dtype) + + return attention_mask + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @replace_example_docstring(EXAMPLE_DOC_STRING) + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float = 1.0, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + padding_mask_crop: int | None = None, + strength: float = 0.6, + num_inference_steps: int = 28, + sigmas: list[float] | None = None, + guidance_scale: float = 7.0, + num_images_per_prompt: Optional[int] = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_ip_adapter_image: Optional[PipelineImageInput] = None, + negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 256, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + not greater than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 35): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + strength (`float, *optional*, defaults to 0.9): + Conceptually, indicates how much to transform the reference image. Must be between 0 and 1. image will + be used as a starting point, adding more noise to it the larger the strength. The number of denoising + steps depends on the amount of noise initially added. When strength is 1, added noise will be maximum + and the denoising process will run for the full number of iterations specified in num_inference_steps. + A value of 1, therefore, essentially ignores image. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_ip_adapter_image: + (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (torch.Tensor, *optional*): + Attention mask for the prompt embeddings. Used to mask out padding tokens in the prompt sequence. + Chroma requires a single padding token remain unmasked. Please refer to + https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training + negative_prompt_attention_mask (torch.Tensor, *optional*): + Attention mask for the negative prompt embeddings. Used to mask out padding tokens in the negative + prompt sequence. Chroma requires a single padding token remain unmasked. PLease refer to + https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.chroma.ChromaPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.chroma.ChromaPipelineOutput`] or `tuple`: [`~pipelines.chroma.ChromaPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + output_type=output_type, + strength=strength, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + image=image, + mask_image=mask_image, + padding_mask_crop=padding_mask_crop, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_embeds, + negative_text_ids, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + do_classifier_free_guidance=self.do_classifier_free_guidance, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + num_channels_transformer = self.transformer.config.in_channels + + latents, noise, image_latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + mask_condition = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( + negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None + ): + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( + negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None + ): + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: + negative_image_embeds = self.prepare_ip_adapter_image_embeds( + negative_ip_adapter_image, + negative_ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + + attention_mask = self._prepare_attention_mask( + batch_size=latents.shape[0], + sequence_length=image_seq_len, + dtype=latents.dtype, + attention_mask=prompt_attention_mask, + ) + negative_attention_mask = self._prepare_attention_mask( + batch_size=latents.shape[0], + sequence_length=image_seq_len, + dtype=latents.dtype, + attention_mask=negative_prompt_attention_mask, + ) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + attention_mask=attention_mask, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + + noise_pred_uncond = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + attention_mask=negative_attention_mask, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # for 64 channel transformer only. + init_latents_proper = image_latents + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ChromaPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chroma/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chroma/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..229b0fe42b90baf2d4ebabac0a96864a3079ab9c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chroma/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class ChromaPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image] | np.ndarray diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chronoedit/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chronoedit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cffe4660977f28c13fddb0f8948b419d43efa79a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chronoedit/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_chronoedit"] = ["ChronoEditPipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_chronoedit import ChronoEditPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chronoedit/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chronoedit/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84c1f4d723202d1031fe605177b4c1bec5b1b4b1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chronoedit/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chronoedit/__pycache__/pipeline_chronoedit.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chronoedit/__pycache__/pipeline_chronoedit.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..adcd69e749c3e755646da0ce362a7b5e5d74f853 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chronoedit/__pycache__/pipeline_chronoedit.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chronoedit/__pycache__/pipeline_output.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chronoedit/__pycache__/pipeline_output.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12927bc8407818fe9e85901a5059c17ed868580e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chronoedit/__pycache__/pipeline_output.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chronoedit/pipeline_chronoedit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chronoedit/pipeline_chronoedit.py new file mode 100644 index 0000000000000000000000000000000000000000..1e0cc0ea5c2aa60167b9d39c72e78d95ae1fe1a8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chronoedit/pipeline_chronoedit.py @@ -0,0 +1,750 @@ +# Copyright 2025 The ChronoEdit Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +from typing import Any, Callable + +import PIL +import regex as re +import torch +from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import WanLoraLoaderMixin +from ...models import AutoencoderKLWan, ChronoEditTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import ChronoEditPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> import numpy as np + >>> from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline + >>> from diffusers.utils import export_to_video, load_image + >>> from transformers import CLIPVisionModel + + >>> # Available models: nvidia/ChronoEdit-14B-Diffusers + >>> model_id = "nvidia/ChronoEdit-14B-Diffusers" + >>> image_encoder = CLIPVisionModel.from_pretrained( + ... model_id, subfolder="image_encoder", torch_dtype=torch.float32 + ... ) + >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + >>> transformer = ChronoEditTransformer3DModel.from_pretrained( + ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = ChronoEditPipeline.from_pretrained( + ... model_id, vae=vae, image_encoder=image_encoder, transformer=transformer, torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> image = load_image("https://huggingface.co/spaces/nvidia/ChronoEdit/resolve/main/examples/3.png") + >>> max_area = 720 * 1280 + >>> aspect_ratio = image.height / image.width + >>> mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] + >>> height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + >>> width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + >>> image = image.resize((width, height)) + >>> prompt = ( + ... "The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cup’s liquid, with gentle steam rising around it to blend with the cozy atmosphere. " + ... "The mouse’s pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacup’s floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood." + ... ) + + >>> output = pipe( + ... image=image, + ... prompt=prompt, + ... height=height, + ... width=width, + ... num_frames=5, + ... guidance_scale=5.0, + ... enable_temporal_reasoning=False, + ... num_temporal_reasoning_steps=0, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=16) + ``` +""" + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class ChronoEditPipeline(DiffusionPipeline, WanLoraLoaderMixin): + r""" + Pipeline for image-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + image_encoder ([`CLIPVisionModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically + the + [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large) + variant. + transformer ([`WanTransformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + image_encoder: CLIPVisionModel, + image_processor: CLIPImageProcessor, + transformer: ChronoEditTransformer3DModel, + vae: AutoencoderKLWan, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + image_encoder=image_encoder, + transformer=transformer, + scheduler=scheduler, + image_processor=image_processor, + ) + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.image_processor = image_processor + + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_image + def encode_image( + self, + image: PipelineImageInput, + device: torch.device | None = None, + ): + device = device or self._execution_device + image = self.image_processor(images=image, return_tensors="pt").to(device) + image_embeds = self.image_encoder(**image, output_hidden_states=True) + return image_embeds.hidden_states[-2] + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + # modified from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.check_inputs + def check_inputs( + self, + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if image is not None and image_embeds is not None: + raise ValueError( + f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to" + " only forward one of the two." + ) + if image is None and image_embeds is None: + raise ValueError( + "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined." + ) + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + # modified from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.prepare_latents + def prepare_latents( + self, + image: PipelineImageInput, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + image = image.unsqueeze(2) # [batch_size, channels, 1, height, width] + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 + ) + video_condition = video_condition.to(device=device, dtype=self.vae.dtype) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + + if isinstance(generator, list): + latent_condition = [ + retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator + ] + latent_condition = torch.cat(latent_condition) + else: + latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") + latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + + latent_condition = latent_condition.to(dtype) + latent_condition = (latent_condition - latents_mean) * latents_std + + mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) + mask_lat_size[:, :, list(range(1, num_frames))] = 0 + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(latent_condition.device) + + return latents, torch.concat([mask_lat_size, latent_condition], dim=1) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + image_embeds: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int, None], PipelineCallback | MultiPipelineCallbacks] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + enable_temporal_reasoning: bool = False, + num_temporal_reasoning_steps: int = 0, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, defaults to `480`): + The height of the generated video. + width (`int`, defaults to `832`): + The width of the generated video. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `negative_prompt` input argument. + image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided, + image embeddings are generated from the `image` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ChronoEditPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + enable_temporal_reasoning (`bool`, *optional*, defaults to `False`): + Whether to enable temporal reasoning. + num_temporal_reasoning_steps (`int`, *optional*, defaults to `0`): + The number of steps to enable temporal reasoning. + + Examples: + + Returns: + [`~ChronoEditPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`ChronoEditPipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds, + negative_prompt_embeds, + image_embeds, + callback_on_step_end_tensor_inputs, + ) + + num_frames = 5 if not enable_temporal_reasoning else num_frames + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + # Encode image embedding + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + if image_embeds is None: + image_embeds = self.encode_image(image, device) + image_embeds = image_embeds.repeat(batch_size, 1, 1) + image_embeds = image_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.vae.config.z_dim + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + latents, condition = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + if enable_temporal_reasoning and i == num_temporal_reasoning_steps: + latents = latents[:, :, [0, -1]] + condition = condition[:, :, [0, -1]] + + for j in range(len(self.scheduler.model_outputs)): + if self.scheduler.model_outputs[j] is not None: + if latents.shape[-3] != self.scheduler.model_outputs[j].shape[-3]: + self.scheduler.model_outputs[j] = self.scheduler.model_outputs[j][:, :, [0, -1]] + if self.scheduler.last_sample is not None: + self.scheduler.last_sample = self.scheduler.last_sample[:, :, [0, -1]] + + self._current_timestep = t + latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + if enable_temporal_reasoning and latents.shape[2] > 2: + video_edit = self.vae.decode(latents[:, :, [0, -1]], return_dict=False)[0] + video_reason = self.vae.decode(latents[:, :, :-1], return_dict=False)[0] + video = torch.cat([video_reason, video_edit[:, :, 1:]], dim=2) + else: + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return ChronoEditPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chronoedit/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chronoedit/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..6247ce9f3a0c57e5ee3e185d4c31211261e4379c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/chronoedit/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class ChronoEditPipelineOutput(BaseOutput): + r""" + Output class for ChronoEdit pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogvideo/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogvideo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e4fa1dda53d3f6500d408a2adc199a27abfeb8c3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogvideo/__init__.py @@ -0,0 +1,54 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"] + _import_structure["pipeline_cogvideox_fun_control"] = ["CogVideoXFunControlPipeline"] + _import_structure["pipeline_cogvideox_image2video"] = ["CogVideoXImageToVideoPipeline"] + _import_structure["pipeline_cogvideox_video2video"] = ["CogVideoXVideoToVideoPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_cogvideox import CogVideoXPipeline + from .pipeline_cogvideox_fun_control import CogVideoXFunControlPipeline + from .pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline + from .pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogvideo/pipeline_cogvideox.py new file mode 100644 index 0000000000000000000000000000000000000000..b883e10a673235c71cd6cd1167fa5dfcf6efd31d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -0,0 +1,793 @@ +# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import CogVideoXLoraLoaderMixin +from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from ...models.embeddings import get_3d_rotary_pos_embed +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import CogVideoXPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogVideoXPipeline + >>> from diffusers.utils import export_to_video + + >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b" + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda") + >>> prompt = ( + ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + ... "atmosphere of this unique musical performance." + ... ) + >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: CogVideoXDDIMScheduler | CogVideoXDPMScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4 + ) + self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7 + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae_scaling_factor_image * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + p = self.transformer.config.patch_size + p_t = self.transformer.config.patch_size_t + + base_size_width = self.transformer.config.sample_width // p + base_size_height = self.transformer.config.sample_height // p + + if p_t is None: + # CogVideoX 1.0 + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + device=device, + ) + else: + # CogVideoX 1.5 + base_num_frames = (num_frames + p_t - 1) // p_t + + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=None, + grid_size=(grid_height, grid_width), + temporal_size=base_num_frames, + grid_type="slice", + max_size=(base_size_height, base_size_width), + device=device, + ) + + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_frames: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] | None = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 226, + ) -> CogVideoXPipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The width in pixels of the generated image. This is set to 720 by default for the best results. + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_frames = num_frames or self.transformer.config.sample_frames + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents + latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + + # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t + patch_size_t = self.transformer.config.patch_size_t + additional_frames = 0 + if patch_size_t is not None and latent_frames % patch_size_t != 0: + additional_frames = patch_size_t - latent_frames % patch_size_t + num_frames += additional_frames * self.vae_scale_factor_temporal + + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + # Discard any padding frames that were added for CogVideoX 1.5 + latents = latents[:, additional_frames:] + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py new file mode 100644 index 0000000000000000000000000000000000000000..de5b969a9adc94fadda91b599fe550686b3bcfa1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -0,0 +1,846 @@ +# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI, Alibaba-PAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable + +import torch +from PIL import Image +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import CogVideoXLoraLoaderMixin +from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from ...models.embeddings import get_3d_rotary_pos_embed +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import CogVideoXPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogVideoXFunControlPipeline, DDIMScheduler + >>> from diffusers.utils import export_to_video, load_video + + >>> pipe = CogVideoXFunControlPipeline.from_pretrained( + ... "alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + >>> pipe.to("cuda") + + >>> control_video = load_video( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4" + ... ) + >>> prompt = ( + ... "An astronaut stands triumphantly at the peak of a towering mountain. Panorama of rugged peaks and " + ... "valleys. Very futuristic vibe and animated aesthetic. Highlights of purple and golden colors in " + ... "the scene. The sky is looks like an animated/cartoonish dream of galaxies, nebulae, stars, planets, " + ... "moons, but the remainder of the scene is mostly realistic." + ... ) + + >>> video = pipe(prompt=prompt, control_video=control_video).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + + +# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): + r""" + Pipeline for controlled text-to-video generation using CogVideoX Fun. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->vae->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4 + ) + self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7 + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Adapted from https://github.com/aigc-apps/CogVideoX-Fun/blob/2a93e5c14e02b2b5921d533fd59fc8c0ed69fb24/cogvideox/pipeline/pipeline_cogvideox_control.py#L366 + def prepare_control_latents( + self, mask: torch.Tensor | None = None, masked_image: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + if mask is not None: + masks = [] + for i in range(mask.size(0)): + current_mask = mask[i].unsqueeze(0) + current_mask = self.vae.encode(current_mask)[0] + current_mask = current_mask.mode() + masks.append(current_mask) + mask = torch.cat(masks, dim=0) + mask = mask * self.vae.config.scaling_factor + + if masked_image is not None: + mask_pixel_values = [] + for i in range(masked_image.size(0)): + mask_pixel_value = masked_image[i].unsqueeze(0) + mask_pixel_value = self.vae.encode(mask_pixel_value)[0] + mask_pixel_value = mask_pixel_value.mode() + mask_pixel_values.append(mask_pixel_value) + masked_image_latents = torch.cat(mask_pixel_values, dim=0) + masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + else: + masked_image_latents = None + + return mask, masked_image_latents + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae_scaling_factor_image * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + control_video=None, + control_video_latents=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if control_video is not None and control_video_latents is not None: + raise ValueError( + "Cannot pass both `control_video` and `control_video_latents`. Please make sure to pass only one of these parameters." + ) + + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + p = self.transformer.config.patch_size + p_t = self.transformer.config.patch_size_t + + base_size_width = self.transformer.config.sample_width // p + base_size_height = self.transformer.config.sample_height // p + + if p_t is None: + # CogVideoX 1.0 + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + device=device, + ) + else: + # CogVideoX 1.5 + base_num_frames = (num_frames + p_t - 1) // p_t + + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=None, + grid_size=(grid_height, grid_width), + temporal_size=base_num_frames, + grid_type="slice", + max_size=(base_size_height, base_size_width), + device=device, + ) + + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + control_video: list[Image.Image] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] | None = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + control_video_latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 226, + ) -> CogVideoXPipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + control_video (`list[PIL.Image.Image]`): + The control video to condition the generation on. Must be a list of images/frames of the video. If not + provided, `control_video_latents` must be provided. + height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The width in pixels of the generated image. This is set to 720 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 6.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + control_video_latents (`torch.Tensor`, *optional*): + Pre-generated control latents, sampled from a Gaussian distribution, to be used as inputs for + controlled video generation. If not provided, `control_video` must be provided. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + if control_video is not None and isinstance(control_video[0], Image.Image): + control_video = [control_video] + + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_frames = len(control_video[0]) if control_video is not None else control_video_latents.size(2) + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + control_video, + control_video_latents, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents + latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + + # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t + patch_size_t = self.transformer.config.patch_size_t + if patch_size_t is not None and latent_frames % patch_size_t != 0: + raise ValueError( + f"The number of latent frames must be divisible by `{patch_size_t=}` but the given video " + f"contains {latent_frames=}, which is not divisible." + ) + + latent_channels = self.transformer.config.in_channels // 2 + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + if control_video_latents is None: + control_video = self.video_processor.preprocess_video(control_video, height=height, width=width) + control_video = control_video.to(device=device, dtype=prompt_embeds.dtype) + + _, control_video_latents = self.prepare_control_latents(None, control_video) + control_video_latents = control_video_latents.permute(0, 2, 1, 3, 4) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + latent_control_input = ( + torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents + ) + latent_model_input = torch.cat([latent_model_input, latent_control_input], dim=2) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py new file mode 100644 index 0000000000000000000000000000000000000000..9687d63bc7bf2a95b064f6bb0c1fd441543f6b5e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -0,0 +1,903 @@ +# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable + +import PIL +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import CogVideoXLoraLoaderMixin +from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from ...models.embeddings import get_3d_rotary_pos_embed +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import CogVideoXPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import CogVideoXImageToVideoPipeline + >>> from diffusers.utils import export_to_video, load_image + + >>> pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" + ... ) + >>> video = pipe(image, prompt, use_dynamic_cfg=True) + >>> export_to_video(video.frames[0], "output.mp4", fps=8) + ``` +""" + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): + r""" + Pipeline for image-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: CogVideoXDDIMScheduler | CogVideoXDPMScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4 + ) + self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7 + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, + image: torch.Tensor, + batch_size: int = 1, + num_channels_latents: int = 16, + num_frames: int = 13, + height: int = 60, + width: int = 90, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_frames, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + # For CogVideoX1.5, the latent should add 1 for padding (Not use) + if self.transformer.config.patch_size_t is not None: + shape = shape[:1] + (shape[1] + shape[1] % self.transformer.config.patch_size_t,) + shape[2:] + + image = image.unsqueeze(2) # [B, C, F, H, W] + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + else: + image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image] + + image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + + if not self.vae.config.invert_scale_latents: + image_latents = self.vae_scaling_factor_image * image_latents + else: + # This is awkward but required because the CogVideoX team forgot to multiply the + # scaling factor during training :) + image_latents = 1 / self.vae_scaling_factor_image * image_latents + + padding_shape = ( + batch_size, + num_frames - 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype) + image_latents = torch.cat([image_latents, latent_padding], dim=1) + + # Select the first frame along the second dimension + if self.transformer.config.patch_size_t is not None: + first_frame = image_latents[:, : image_latents.size(1) % self.transformer.config.patch_size_t, ...] + image_latents = torch.cat([first_frame, image_latents], dim=1) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents, image_latents + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae_scaling_factor_image * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, timesteps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + image, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + latents=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `list[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.unfuse_qkv_projections + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + p = self.transformer.config.patch_size + p_t = self.transformer.config.patch_size_t + + base_size_width = self.transformer.config.sample_width // p + base_size_height = self.transformer.config.sample_height // p + + if p_t is None: + # CogVideoX 1.0 + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + device=device, + ) + else: + # CogVideoX 1.5 + base_num_frames = (num_frames + p_t - 1) // p_t + + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=None, + grid_size=(grid_height, grid_width), + temporal_size=base_num_frames, + grid_type="slice", + max_size=(base_size_height, base_size_width), + device=device, + ) + + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: list[int] | None = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 226, + ) -> CogVideoXPipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The width in pixels of the generated image. This is set to 720 by default for the best results. + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_frames = num_frames or self.transformer.config.sample_frames + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + image=image, + prompt=prompt, + height=height, + width=width, + negative_prompt=negative_prompt, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._current_timestep = None + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents + latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + + # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t + patch_size_t = self.transformer.config.patch_size_t + additional_frames = 0 + if patch_size_t is not None and latent_frames % patch_size_t != 0: + additional_frames = patch_size_t - latent_frames % patch_size_t + num_frames += additional_frames * self.vae_scale_factor_temporal + + image = self.video_processor.preprocess(image, height=height, width=width).to( + device, dtype=prompt_embeds.dtype + ) + + latent_channels = self.transformer.config.in_channels // 2 + latents, image_latents = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Create ofs embeds if required + ofs_emb = None if self.transformer.config.ofs_embed_dim is None else latents.new_full((1,), fill_value=2.0) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents + latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + ofs=ofs_emb, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + # Discard any padding frames that were added for CogVideoX 1.5 + latents = latents[:, additional_frames:] + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py new file mode 100644 index 0000000000000000000000000000000000000000..e3ce8292fad6efd1e3b89fbad206cb02fbf3f14b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -0,0 +1,872 @@ +# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable + +import torch +from PIL import Image +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import CogVideoXLoraLoaderMixin +from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from ...models.embeddings import get_3d_rotary_pos_embed +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import CogVideoXPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogVideoXDPMScheduler, CogVideoXVideoToVideoPipeline + >>> from diffusers.utils import export_to_video, load_video + + >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b" + >>> pipe = CogVideoXVideoToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config) + + >>> input_video = load_video( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4" + ... ) + >>> prompt = ( + ... "An astronaut stands triumphantly at the peak of a towering mountain. Panorama of rugged peaks and " + ... "valleys. Very futuristic vibe and animated aesthetic. Highlights of purple and golden colors in " + ... "the scene. The sky is looks like an animated/cartoonish dream of galaxies, nebulae, stars, planets, " + ... "moons, but the remainder of the scene is mostly realistic." + ... ) + + >>> video = pipe( + ... video=input_video, prompt=prompt, strength=0.8, guidance_scale=6, num_inference_steps=50 + ... ).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): + r""" + Pipeline for video-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: CogVideoXDDIMScheduler | CogVideoXDPMScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4 + ) + self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7 + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, + video: torch.Tensor | None = None, + batch_size: int = 1, + num_channels_latents: int = 16, + height: int = 60, + width: int = 90, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + timestep: torch.Tensor | None = None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + num_frames = (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1) + + shape = ( + batch_size, + num_frames, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if latents is None: + if isinstance(generator, list): + init_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + else: + init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + + init_latents = torch.cat(init_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + init_latents = self.vae_scaling_factor_image * init_latents + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.add_noise(init_latents, noise, timestep) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae_scaling_factor_image * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, timesteps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + strength, + negative_prompt, + callback_on_step_end_tensor_inputs, + video=None, + latents=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if video is not None and latents is not None: + raise ValueError("Only one of `video` or `latents` should be provided") + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.unfuse_qkv_projections + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + p = self.transformer.config.patch_size + p_t = self.transformer.config.patch_size_t + + base_size_width = self.transformer.config.sample_width // p + base_size_height = self.transformer.config.sample_height // p + + if p_t is None: + # CogVideoX 1.0 + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + device=device, + ) + else: + # CogVideoX 1.5 + base_num_frames = (num_frames + p_t - 1) // p_t + + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=None, + grid_size=(grid_height, grid_width), + temporal_size=base_num_frames, + grid_type="slice", + max_size=(base_size_height, base_size_width), + device=device, + ) + + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + video: list[Image.Image] = None, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] | None = None, + strength: float = 0.8, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 226, + ) -> CogVideoXPipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + video (`list[PIL.Image.Image]`): + The input video to condition the generation on. Must be a list of images/frames of the video. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The width in pixels of the generated image. This is set to 720 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + strength (`float`, *optional*, defaults to 0.8): + Higher strength leads to more differences between original video and generated video. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_frames = len(video) if latents is None else latents.size(1) + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + strength=strength, + negative_prompt=negative_prompt, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + video=video, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents + latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + + # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t + patch_size_t = self.transformer.config.patch_size_t + if patch_size_t is not None and latent_frames % patch_size_t != 0: + raise ValueError( + f"The number of latent frames must be divisible by `{patch_size_t=}` but the given video " + f"contains {latent_frames=}, which is not divisible." + ) + + if latents is None: + video = self.video_processor.preprocess_video(video, height=height, width=width) + video = video.to(device=device, dtype=prompt_embeds.dtype) + + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + video, + batch_size * num_videos_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + latent_timestep, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogvideo/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogvideo/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..56ff501322318327bd7a48fee7884aceb3b984c0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogvideo/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class CogVideoXPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogview3/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogview3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..50895251ba0b67dc2848b99654cbdd08312833f5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogview3/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["CogView3PlusPipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_cogview3plus"] = ["CogView3PlusPipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_cogview3plus import CogView3PlusPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogview3/pipeline_cogview3plus.py new file mode 100644 index 0000000000000000000000000000000000000000..8880e3a0d1e2da02956d64853dbaca89429ced01 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogview3/pipeline_cogview3plus.py @@ -0,0 +1,686 @@ +# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, CogView3PlusTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import CogView3PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogView3PlusPipeline + + >>> pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3-Plus-3B", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + >>> image.save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class CogView3PlusPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using CogView3Plus. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogView3Plus uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogView3PlusTransformer2DModel`]): + A text conditioned `CogView3PlusTransformer2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: CogView3PlusTransformer2DModel, + scheduler: CogVideoXDDIMScheduler | CogVideoXDPMScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds with num_videos_per_prompt->num_images_per_prompt + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 224, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + max_sequence_length (`int`, defaults to `224`): + Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt is None: + negative_prompt_embeds = prompt_embeds.new_zeros(prompt_embeds.shape) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] | None = None, + guidance_scale: float = 5.0, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + original_size: tuple[int, int] | None = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 224, + ) -> CogView3PipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. If not provided, it is set to 1024. + width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. If not provided it is set to 1024. + num_inference_steps (`int`, *optional*, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to `1`): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `224`): + Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. + + Examples: + + Returns: + [`~pipelines.cogview3.pipeline_cogview3plus.CogView3PipelineOutput`] or `tuple`: + [`~pipelines.cogview3.pipeline_cogview3plus.CogView3PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + self.do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare additional timestep conditions + original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype) + target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype) + crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype) + + if self.do_classifier_free_guidance: + original_size = torch.cat([original_size, original_size]) + target_size = torch.cat([target_size, target_size]) + crops_coords_top_left = torch.cat([crops_coords_top_left, crops_coords_top_left]) + + original_size = original_size.to(device).repeat(batch_size * num_images_per_prompt, 1) + target_size = target_size.to(device).repeat(batch_size * num_images_per_prompt, 1) + crops_coords_top_left = crops_coords_top_left.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return CogView3PipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogview3/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogview3/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..6c89e117b74ccc1edb6e098bbdc84098d978230a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogview3/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class CogView3PipelineOutput(BaseOutput): + """ + Output class for CogView3 pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image] | np.ndarray diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogview4/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogview4/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6a365e17fee7185cbfb67e857ec2d265c35674ab --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogview4/__init__.py @@ -0,0 +1,49 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["CogView4PlusPipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_cogview4"] = ["CogView4Pipeline"] + _import_structure["pipeline_cogview4_control"] = ["CogView4ControlPipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_cogview4 import CogView4Pipeline + from .pipeline_cogview4_control import CogView4ControlPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogview4/pipeline_cogview4.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogview4/pipeline_cogview4.py new file mode 100644 index 0000000000000000000000000000000000000000..329b76d11e0d404935d39820d671ff413de489a8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogview4/pipeline_cogview4.py @@ -0,0 +1,687 @@ +# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import AutoTokenizer, GlmModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...loaders import CogView4LoraLoaderMixin +from ...models import AutoencoderKL, CogView4Transformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import CogView4PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogView4Pipeline + + >>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + >>> image.save("output.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + base_shift: float = 0.25, + max_shift: float = 0.75, +) -> float: + m = (image_seq_len / base_seq_len) ** 0.5 + mu = m * max_shift + base_shift + return mu + + +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + + if timesteps is not None and sigmas is not None: + if not accepts_timesteps and not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep or sigma schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif timesteps is not None and sigmas is None: + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif timesteps is None and sigmas is not None: + if not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin): + r""" + Pipeline for text-to-image generation using CogView4. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`GLMModel`]): + Frozen text-encoder. CogView4 uses [glm-4-9b-hf](https://huggingface.co/THUDM/glm-4-9b-hf). + tokenizer (`PreTrainedTokenizer`): + Tokenizer of class + [PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer). + transformer ([`CogView4Transformer2DModel`]): + A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: GlmModel, + vae: AutoencoderKL, + transformer: CogView4Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def _get_glm_embeds( + self, + prompt: str | list[str] = None, + max_sequence_length: int = 1024, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer( + prompt, + padding="longest", # not use max length + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + current_length = text_input_ids.shape[1] + pad_length = (16 - (current_length % 16)) % 16 + if pad_length > 0: + pad_ids = torch.full( + (text_input_ids.shape[0], pad_length), + fill_value=self.tokenizer.pad_token_id, + dtype=text_input_ids.dtype, + device=text_input_ids.device, + ) + text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=True).hidden_states[-2] + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 1024, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + max_sequence_length (`int`, defaults to `1024`): + Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_glm_embeds(prompt, max_sequence_length, device, dtype) + + seq_len = prompt_embeds.size(1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_glm_embeds(negative_prompt, max_sequence_length, device, dtype) + + seq_len = negative_prompt_embeds.size(1) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if latents is not None: + return latents.to(device) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape[0] != negative_prompt_embeds.shape[0]: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same batch size when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} and `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds.shape[-1] != negative_prompt_embeds.shape[-1]: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same dimension when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} and `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + guidance_scale: float = 5.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + original_size: tuple[int, int] | None = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 1024, + ) -> CogView4PipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. If not provided, it is set to 1024. + width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. If not provided it is set to 1024. + num_inference_steps (`int`, *optional*, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to `1`): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `224`): + Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. + + Examples: + + Returns: + [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] or `tuple`: + [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = (height, width) + + # Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + self.do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + # Prepare latents + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # Prepare additional timestep conditions + original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=device) + target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device) + crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device) + + original_size = original_size.repeat(batch_size * num_images_per_prompt, 1) + target_size = target_size.repeat(batch_size * num_images_per_prompt, 1) + crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1) + + # Prepare timesteps + image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // ( + self.transformer.config.patch_size**2 + ) + timesteps = ( + np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps) + if timesteps is None + else np.array(timesteps) + ) + timesteps = timesteps.astype(np.int64).astype(np.float32) + sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("base_shift", 0.25), + self.scheduler.config.get("max_shift", 0.75), + ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas, mu=mu + ) + self._num_timesteps = len(timesteps) + + # Denoising loop + transformer_dtype = self.transformer.dtype + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(transformer_dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + + with self.transformer.cache_context("cond"): + noise_pred_cond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=negative_prompt_embeds, + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = noise_pred_cond + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + image = self.vae.decode(latents, return_dict=False, generator=generator)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return CogView4PipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogview4/pipeline_cogview4_control.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogview4/pipeline_cogview4_control.py new file mode 100644 index 0000000000000000000000000000000000000000..6282bf4cd7a405fd075daf5c1913f342f2c6753f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogview4/pipeline_cogview4_control.py @@ -0,0 +1,734 @@ +# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import AutoTokenizer, GlmModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import AutoencoderKL, CogView4Transformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import CogView4PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogView4ControlPipeline + + >>> pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B-Control", torch_dtype=torch.bfloat16) + >>> control_image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ... ) + >>> prompt = "A bird in space" + >>> image = pipe(prompt, control_image=control_image, height=1024, width=1024, guidance_scale=3.5).images[0] + >>> image.save("cogview4-control.png") + ``` +""" + + +# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + base_shift: float = 0.25, + max_shift: float = 0.75, +) -> float: + m = (image_seq_len / base_seq_len) ** 0.5 + mu = m * max_shift + base_shift + return mu + + +# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + + if timesteps is not None and sigmas is not None: + if not accepts_timesteps and not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep or sigma schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif timesteps is not None and sigmas is None: + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif timesteps is None and sigmas is not None: + if not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class CogView4ControlPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using CogView4. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`GLMModel`]): + Frozen text-encoder. CogView4 uses [glm-4-9b-hf](https://huggingface.co/THUDM/glm-4-9b-hf). + tokenizer (`PreTrainedTokenizer`): + Tokenizer of class + [PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer). + transformer ([`CogView4Transformer2DModel`]): + A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: GlmModel, + vae: AutoencoderKL, + transformer: CogView4Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline._get_glm_embeds + def _get_glm_embeds( + self, + prompt: str | list[str] = None, + max_sequence_length: int = 1024, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer( + prompt, + padding="longest", # not use max length + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + current_length = text_input_ids.shape[1] + pad_length = (16 - (current_length % 16)) % 16 + if pad_length > 0: + pad_ids = torch.full( + (text_input_ids.shape[0], pad_length), + fill_value=self.tokenizer.pad_token_id, + dtype=text_input_ids.dtype, + device=text_input_ids.device, + ) + text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=True).hidden_states[-2] + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + return prompt_embeds + + # Copied from diffusers.pipelines.cogview4.pipeline_cogview4.CogView4Pipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 1024, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + max_sequence_length (`int`, defaults to `1024`): + Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_glm_embeds(prompt, max_sequence_length, device, dtype) + + seq_len = prompt_embeds.size(1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_glm_embeds(negative_prompt, max_sequence_length, device, dtype) + + seq_len = negative_prompt_embeds.size(1) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if latents is not None: + return latents.to(device) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0, output_size=image.shape[0] * repeat_by) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + control_image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + guidance_scale: float = 5.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + original_size: tuple[int, int] | None = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 1024, + ) -> CogView4PipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. If not provided, it is set to 1024. + width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. If not provided it is set to 1024. + num_inference_steps (`int`, *optional*, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to `1`): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.pipeline_CogView4.CogView4PipelineOutput`] instead of a plain + tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `224`): + Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. + Examples: + + Returns: + [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] or `tuple`: + [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = (height, width) + + # Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + self.do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + # Prepare latents + latent_channels = self.transformer.config.in_channels // 2 + + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image.shape[-2:] + + vae_shift_factor = 0 + + control_image = self.vae.encode(control_image).latent_dist.sample() + control_image = (control_image - vae_shift_factor) * self.vae.config.scaling_factor + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # Prepare additional timestep conditions + original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=device) + target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device) + crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device) + + original_size = original_size.repeat(batch_size * num_images_per_prompt, 1) + target_size = target_size.repeat(batch_size * num_images_per_prompt, 1) + crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1) + + # Prepare timesteps + image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // ( + self.transformer.config.patch_size**2 + ) + + timesteps = ( + np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps) + if timesteps is None + else np.array(timesteps) + ) + timesteps = timesteps.astype(np.int64).astype(np.float32) + sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("base_shift", 0.25), + self.scheduler.config.get("max_shift", 0.75), + ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas, mu=mu + ) + self._num_timesteps = len(timesteps) + # Denoising loop + transformer_dtype = self.transformer.dtype + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents, control_image], dim=1).to(transformer_dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + + noise_pred_cond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=negative_prompt_embeds, + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = noise_pred_cond + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + image = self.vae.decode(latents, return_dict=False, generator=generator)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return CogView4PipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogview4/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogview4/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..997444c6c009a27bf944f84dbe6025b4345f87e1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cogview4/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class CogView4PipelineOutput(BaseOutput): + """ + Output class for CogView3 pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image] | np.ndarray diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/consisid/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/consisid/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7b9ba330fbd129d104180a2c0e5d05672b48fa47 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/consisid/__init__.py @@ -0,0 +1,49 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_opencv_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available() and is_opencv_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_and_opencv_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_opencv_objects)) +else: + _import_structure["pipeline_consisid"] = ["ConsisIDPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_consisid import ConsisIDPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/consisid/consisid_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/consisid/consisid_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c1646e15efbc48c36803ac71b4b6ce198075b35b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/consisid/consisid_utils.py @@ -0,0 +1,357 @@ +import importlib.util +import os + +import cv2 +import numpy as np +import torch +from PIL import Image, ImageOps +from torchvision.transforms import InterpolationMode +from torchvision.transforms.functional import normalize, resize + +from ...utils import get_logger, load_image + + +logger = get_logger(__name__) + +_insightface_available = importlib.util.find_spec("insightface") is not None +_consisid_eva_clip_available = importlib.util.find_spec("consisid_eva_clip") is not None +_facexlib_available = importlib.util.find_spec("facexlib") is not None + +if _insightface_available: + import insightface + from insightface.app import FaceAnalysis +else: + raise ImportError("insightface is not available. Please install it using 'pip install insightface'.") + +if _consisid_eva_clip_available: + from consisid_eva_clip import create_model_and_transforms + from consisid_eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +else: + raise ImportError("consisid_eva_clip is not available. Please install it using 'pip install consisid_eva_clip'.") + +if _facexlib_available: + from facexlib.parsing import init_parsing_model + from facexlib.utils.face_restoration_helper import FaceRestoreHelper +else: + raise ImportError("facexlib is not available. Please install it using 'pip install facexlib'.") + + +def resize_numpy_image_long(image, resize_long_edge=768): + """ + Resize the input image to a specified long edge while maintaining aspect ratio. + + Args: + image (numpy.ndarray): Input image (H x W x C or H x W). + resize_long_edge (int): The target size for the long edge of the image. Default is 768. + + Returns: + numpy.ndarray: Resized image with the long edge matching `resize_long_edge`, while maintaining the aspect + ratio. + """ + + h, w = image.shape[:2] + if max(h, w) <= resize_long_edge: + return image + k = resize_long_edge / max(h, w) + h = int(h * k) + w = int(w * k) + image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4) + return image + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == "float64": + img = img.astype("float32") + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + return _totensor(imgs, bgr2rgb, float32) + + +def to_gray(img): + """ + Converts an RGB image to grayscale by applying the standard luminosity formula. + + Args: + img (torch.Tensor): The input image tensor with shape (batch_size, channels, height, width). + The image is expected to be in RGB format (3 channels). + + Returns: + torch.Tensor: The grayscale image tensor with shape (batch_size, 3, height, width). + The grayscale values are replicated across all three channels. + """ + x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3] + x = x.repeat(1, 3, 1, 1) + return x + + +def process_face_embeddings( + face_helper_1, + clip_vision_model, + face_helper_2, + eva_transform_mean, + eva_transform_std, + app, + device, + weight_dtype, + image, + original_id_image=None, + is_align_face=True, +): + """ + Process face embeddings from an image, extracting relevant features such as face embeddings, landmarks, and parsed + face features using a series of face detection and alignment tools. + + Args: + face_helper_1: Face helper object (first helper) for alignment and landmark detection. + clip_vision_model: Pre-trained CLIP vision model used for feature extraction. + face_helper_2: Face helper object (second helper) for embedding extraction. + eva_transform_mean: Mean values for image normalization before passing to EVA model. + eva_transform_std: Standard deviation values for image normalization before passing to EVA model. + app: Application instance used for face detection. + device: Device (CPU or GPU) where the computations will be performed. + weight_dtype: Data type of the weights for precision (e.g., `torch.float32`). + image: Input image in RGB format with pixel values in the range [0, 255]. + original_id_image: (Optional) Original image for feature extraction if `is_align_face` is False. + is_align_face: Boolean flag indicating whether face alignment should be performed. + + Returns: + tuple: + - id_cond: Concatenated tensor of Ante face embedding and CLIP vision embedding + - id_vit_hidden: Hidden state of the CLIP vision model, a list of tensors. + - return_face_features_image_2: Processed face features image after normalization and parsing. + - face_kps: Keypoints of the face detected in the image. + """ + + face_helper_1.clean_all() + image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + # get antelopev2 embedding + face_info = app.get(image_bgr) + if len(face_info) > 0: + face_info = sorted(face_info, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))[ + -1 + ] # only use the maximum face + id_ante_embedding = face_info["embedding"] # (512,) + face_kps = face_info["kps"] + else: + id_ante_embedding = None + face_kps = None + + # using facexlib to detect and align face + face_helper_1.read_image(image_bgr) + face_helper_1.get_face_landmarks_5(only_center_face=True) + if face_kps is None: + face_kps = face_helper_1.all_landmarks_5[0] + face_helper_1.align_warp_face() + if len(face_helper_1.cropped_faces) == 0: + raise RuntimeError("facexlib align face fail") + align_face = face_helper_1.cropped_faces[0] # (512, 512, 3) # RGB + + # in case insightface didn't detect face + if id_ante_embedding is None: + logger.warning("Failed to detect face using insightface. Extracting embedding with align face") + id_ante_embedding = face_helper_2.get_feat(align_face) + + id_ante_embedding = torch.from_numpy(id_ante_embedding).to(device, weight_dtype) # torch.Size([512]) + if id_ante_embedding.ndim == 1: + id_ante_embedding = id_ante_embedding.unsqueeze(0) # torch.Size([1, 512]) + + # parsing + if is_align_face: + input = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0 # torch.Size([1, 3, 512, 512]) + input = input.to(device) + parsing_out = face_helper_1.face_parse(normalize(input, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0] + parsing_out = parsing_out.argmax(dim=1, keepdim=True) # torch.Size([1, 1, 512, 512]) + bg_label = [0, 16, 18, 7, 8, 9, 14, 15] + bg = sum(parsing_out == i for i in bg_label).bool() + white_image = torch.ones_like(input) # torch.Size([1, 3, 512, 512]) + # only keep the face features + return_face_features_image = torch.where(bg, white_image, to_gray(input)) # torch.Size([1, 3, 512, 512]) + return_face_features_image_2 = torch.where(bg, white_image, input) # torch.Size([1, 3, 512, 512]) + else: + original_image_bgr = cv2.cvtColor(original_id_image, cv2.COLOR_RGB2BGR) + input = img2tensor(original_image_bgr, bgr2rgb=True).unsqueeze(0) / 255.0 # torch.Size([1, 3, 512, 512]) + input = input.to(device) + return_face_features_image = return_face_features_image_2 = input + + # transform img before sending to eva-clip-vit + face_features_image = resize( + return_face_features_image, clip_vision_model.image_size, InterpolationMode.BICUBIC + ) # torch.Size([1, 3, 336, 336]) + face_features_image = normalize(face_features_image, eva_transform_mean, eva_transform_std) + id_cond_vit, id_vit_hidden = clip_vision_model( + face_features_image.to(weight_dtype), return_all_features=False, return_hidden=True, shuffle=False + ) # torch.Size([1, 768]), list(torch.Size([1, 577, 1024])) + id_cond_vit_norm = torch.norm(id_cond_vit, 2, 1, True) + id_cond_vit = torch.div(id_cond_vit, id_cond_vit_norm) + + id_cond = torch.cat( + [id_ante_embedding, id_cond_vit], dim=-1 + ) # torch.Size([1, 512]), torch.Size([1, 768]) -> torch.Size([1, 1280]) + + return ( + id_cond, + id_vit_hidden, + return_face_features_image_2, + face_kps, + ) # torch.Size([1, 1280]), list(torch.Size([1, 577, 1024])) + + +def process_face_embeddings_infer( + face_helper_1, + clip_vision_model, + face_helper_2, + eva_transform_mean, + eva_transform_std, + app, + device, + weight_dtype, + img_file_path, + is_align_face=True, +): + """ + Process face embeddings from an input image for inference, including alignment, feature extraction, and embedding + concatenation. + + Args: + face_helper_1: Face helper object (first helper) for alignment and landmark detection. + clip_vision_model: Pre-trained CLIP vision model used for feature extraction. + face_helper_2: Face helper object (second helper) for embedding extraction. + eva_transform_mean: Mean values for image normalization before passing to EVA model. + eva_transform_std: Standard deviation values for image normalization before passing to EVA model. + app: Application instance used for face detection. + device: Device (CPU or GPU) where the computations will be performed. + weight_dtype: Data type of the weights for precision (e.g., `torch.float32`). + img_file_path: Path to the input image file (string) or a numpy array representing an image. + is_align_face: Boolean flag indicating whether face alignment should be performed (default: True). + + Returns: + tuple: + - id_cond: Concatenated tensor of Ante face embedding and CLIP vision embedding. + - id_vit_hidden: Hidden state of the CLIP vision model, a list of tensors. + - image: Processed face image after feature extraction and alignment. + - face_kps: Keypoints of the face detected in the image. + """ + + # Load and preprocess the input image + if isinstance(img_file_path, str): + image = np.array(load_image(image=img_file_path).convert("RGB")) + else: + image = np.array(ImageOps.exif_transpose(Image.fromarray(img_file_path)).convert("RGB")) + + # Resize image to ensure the longer side is 1024 pixels + image = resize_numpy_image_long(image, 1024) + original_id_image = image + + # Process the image to extract face embeddings and related features + id_cond, id_vit_hidden, align_crop_face_image, face_kps = process_face_embeddings( + face_helper_1, + clip_vision_model, + face_helper_2, + eva_transform_mean, + eva_transform_std, + app, + device, + weight_dtype, + image, + original_id_image, + is_align_face, + ) + + # Convert the aligned cropped face image (torch tensor) to a numpy array + tensor = align_crop_face_image.cpu().detach() + tensor = tensor.squeeze() + tensor = tensor.permute(1, 2, 0) + tensor = tensor.numpy() * 255 + tensor = tensor.astype(np.uint8) + image = ImageOps.exif_transpose(Image.fromarray(tensor)) + + return id_cond, id_vit_hidden, image, face_kps + + +def prepare_face_models(model_path, device, dtype): + """ + Prepare all face models for the facial recognition task. + + Parameters: + - model_path: Path to the directory containing model files. + - device: The device (e.g., 'cuda', 'xpu', 'cpu') where models will be loaded. + - dtype: Data type (e.g., torch.float32) for model inference. + + Returns: + - face_helper_1: First face restoration helper. + - face_helper_2: Second face restoration helper. + - face_clip_model: CLIP model for face extraction. + - eva_transform_mean: Mean value for image normalization. + - eva_transform_std: Standard deviation value for image normalization. + - face_main_model: Main face analysis model. + """ + # get helper model + face_helper_1 = FaceRestoreHelper( + upscale_factor=1, + face_size=512, + crop_ratio=(1, 1), + det_model="retinaface_resnet50", + save_ext="png", + device=device, + model_rootpath=os.path.join(model_path, "face_encoder"), + ) + face_helper_1.face_parse = None + face_helper_1.face_parse = init_parsing_model( + model_name="bisenet", device=device, model_rootpath=os.path.join(model_path, "face_encoder") + ) + face_helper_2 = insightface.model_zoo.get_model( + f"{model_path}/face_encoder/models/antelopev2/glintr100.onnx", providers=["CUDAExecutionProvider"] + ) + face_helper_2.prepare(ctx_id=0) + + # get local facial extractor part 1 + model, _, _ = create_model_and_transforms( + "EVA02-CLIP-L-14-336", + os.path.join(model_path, "face_encoder", "EVA02_CLIP_L_336_psz14_s6B.pt"), + force_custom_clip=True, + ) + face_clip_model = model.visual + eva_transform_mean = getattr(face_clip_model, "image_mean", OPENAI_DATASET_MEAN) + eva_transform_std = getattr(face_clip_model, "image_std", OPENAI_DATASET_STD) + if not isinstance(eva_transform_mean, (list, tuple)): + eva_transform_mean = (eva_transform_mean,) * 3 + if not isinstance(eva_transform_std, (list, tuple)): + eva_transform_std = (eva_transform_std,) * 3 + eva_transform_mean = eva_transform_mean + eva_transform_std = eva_transform_std + + # get local facial extractor part 2 + face_main_model = FaceAnalysis( + name="antelopev2", root=os.path.join(model_path, "face_encoder"), providers=["CUDAExecutionProvider"] + ) + face_main_model.prepare(ctx_id=0, det_size=(640, 640)) + + # move face models to device + face_helper_1.face_det.eval() + face_helper_1.face_parse.eval() + face_clip_model.eval() + face_helper_1.face_det.to(device) + face_helper_1.face_parse.to(device) + face_clip_model.to(device, dtype=dtype) + + return face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/consisid/pipeline_consisid.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/consisid/pipeline_consisid.py new file mode 100644 index 0000000000000000000000000000000000000000..20b779bf5aaa50a028315d8001d5fa430a69bd86 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/consisid/pipeline_consisid.py @@ -0,0 +1,972 @@ +# Copyright 2025 ConsisID Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable + +import numpy as np +import PIL +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import CogVideoXLoraLoaderMixin +from ...models import AutoencoderKLCogVideoX, ConsisIDTransformer3DModel +from ...models.embeddings import get_3d_rotary_pos_embed +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDPMScheduler +from ...utils import is_opencv_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import ConsisIDPipelineOutput + + +if is_opencv_available(): + import cv2 + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import ConsisIDPipeline + >>> from diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer + >>> from diffusers.utils import export_to_video + >>> from huggingface_hub import snapshot_download + + >>> snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir="BestWishYsh/ConsisID-preview") + >>> ( + ... face_helper_1, + ... face_helper_2, + ... face_clip_model, + ... face_main_model, + ... eva_transform_mean, + ... eva_transform_std, + ... ) = prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16) + >>> pipe = ConsisIDPipeline.from_pretrained("BestWishYsh/ConsisID-preview", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> # ConsisID works well with long and well-described prompts. Make sure the face in the image is clearly visible (e.g., preferably half-body or full-body). + >>> prompt = "The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel." + >>> image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_input.png?download=true" + + >>> id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer( + ... face_helper_1, + ... face_clip_model, + ... face_helper_2, + ... eva_transform_mean, + ... eva_transform_std, + ... face_main_model, + ... "cuda", + ... torch.bfloat16, + ... image, + ... is_align_face=True, + ... ) + + >>> video = pipe( + ... image=image, + ... prompt=prompt, + ... num_inference_steps=50, + ... guidance_scale=6.0, + ... use_dynamic_cfg=False, + ... id_vit_hidden=id_vit_hidden, + ... id_cond=id_cond, + ... kps_cond=face_kps, + ... generator=torch.Generator("cuda").manual_seed(42), + ... ) + >>> export_to_video(video.frames[0], "output.mp4", fps=8) + ``` +""" + + +def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]): + """ + This function draws keypoints and the limbs connecting them on an image. + + Parameters: + - image_pil (PIL.Image): Input image as a PIL object. + - kps (list of tuples): A list of keypoints where each keypoint is a tuple of (x, y) coordinates. + - color_list (list of tuples, optional): list of colors (in RGB format) for each keypoint. Default is a set of five + colors. + + Returns: + - PIL.Image: Image with the keypoints and limbs drawn. + """ + + stickwidth = 4 + limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]]) + kps = np.array(kps) + + w, h = image_pil.size + out_img = np.zeros([h, w, 3]) + + for i in range(len(limbSeq)): + index = limbSeq[i] + color = color_list[index[0]] + + x = kps[index][:, 0] + y = kps[index][:, 1] + length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])) + polygon = cv2.ellipse2Poly( + (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1 + ) + out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color) + out_img = (out_img * 0.6).astype(np.uint8) + + for idx_kp, kp in enumerate(kps): + color = color_list[idx_kp] + x, y = kp + out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1) + + out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8)) + return out_img_pil + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + """ + This function calculates the resize and crop region for an image to fit a target width and height while preserving + the aspect ratio. + + Parameters: + - src (tuple): A tuple containing the source image's height (h) and width (w). + - tgt_width (int): The target width to resize the image. + - tgt_height (int): The target height to resize the image. + + Returns: + - tuple: Two tuples representing the crop region: + 1. The top-left coordinates of the crop region. + 2. The bottom-right coordinates of the crop region. + """ + + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class ConsisIDPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): + r""" + Pipeline for image-to-video generation using ConsisID. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. ConsisID uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`ConsisIDTransformer3DModel`]): + A text conditioned `ConsisIDTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: ConsisIDTransformer3DModel, + scheduler: CogVideoXDPMScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + self.vae_scaling_factor_image = ( + self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, + image: torch.Tensor, + batch_size: int = 1, + num_channels_latents: int = 16, + num_frames: int = 13, + height: int = 60, + width: int = 90, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + kps_cond: torch.Tensor | None = None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_frames, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + image = image.unsqueeze(2) # [B, C, F, H, W] + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + if kps_cond is not None: + kps_cond = kps_cond.unsqueeze(2) + kps_cond_latents = [ + retrieve_latents(self.vae.encode(kps_cond[i].unsqueeze(0)), generator[i]) + for i in range(batch_size) + ] + else: + image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image] + if kps_cond is not None: + kps_cond = kps_cond.unsqueeze(2) + kps_cond_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in kps_cond] + + image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + image_latents = self.vae_scaling_factor_image * image_latents + + if kps_cond is not None: + kps_cond_latents = torch.cat(kps_cond_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W] + kps_cond_latents = self.vae_scaling_factor_image * kps_cond_latents + + padding_shape = ( + batch_size, + num_frames - 2, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + else: + padding_shape = ( + batch_size, + num_frames - 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype) + if kps_cond is not None: + image_latents = torch.cat([image_latents, kps_cond_latents, latent_padding], dim=1) + else: + image_latents = torch.cat([image_latents, latent_padding], dim=1) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents, image_latents + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae_scaling_factor_image * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, timesteps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + image, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + latents=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `list[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_width = self.transformer.config.sample_width // self.transformer.config.patch_size + base_size_height = self.transformer.config.sample_height // self.transformer.config.patch_size + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + device=device, + ) + + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + guidance_scale: float = 6.0, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 226, + id_vit_hidden: torch.Tensor | None = None, + id_cond: torch.Tensor | None = None, + kps_cond: torch.Tensor | None = None, + ) -> ConsisIDPipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The width in pixels of the generated image. This is set to 720 by default for the best results. + num_frames (`int`, defaults to `49`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because ConsisID is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 6): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + use_dynamic_cfg (`bool`, *optional*, defaults to `False`): + If True, dynamically adjusts the guidance scale during inference. This allows the model to use a + progressive guidance scale, improving the balance between text-guided generation and image quality over + the course of the inference steps. Typically, early inference steps use a higher guidance scale for + more faithful image generation, while later steps reduce it for more diverse and natural results. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + id_vit_hidden (`torch.Tensor | None`, *optional*): + The tensor representing the hidden features extracted from the face model, which are used to condition + the local facial extractor. This is crucial for the model to obtain high-frequency information of the + face. If not provided, the local facial extractor will not run normally. + id_cond (`torch.Tensor | None`, *optional*): + The tensor representing the hidden features extracted from the clip model, which are used to condition + the local facial extractor. This is crucial for the model to edit facial features If not provided, the + local facial extractor will not run normally. + kps_cond (`torch.Tensor | None`, *optional*): + A tensor that determines whether the global facial extractor use keypoint information for conditioning. + If provided, this tensor controls whether facial keypoints such as eyes, nose, and mouth landmarks are + used during the generation process. This helps ensure the model retains more facial low-frequency + information. + + Examples: + + Returns: + [`~pipelines.consisid.pipeline_output.ConsisIDPipelineOutput`] or `tuple`: + [`~pipelines.consisid.pipeline_output.ConsisIDPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_frames = num_frames or self.transformer.config.sample_frames + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + image=image, + prompt=prompt, + height=height, + width=width, + negative_prompt=negative_prompt, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents + is_kps = getattr(self.transformer.config, "is_kps", False) + kps_cond = kps_cond if is_kps else None + if kps_cond is not None: + kps_cond = draw_kps(image, kps_cond) + kps_cond = self.video_processor.preprocess(kps_cond, height=height, width=width).to( + device, dtype=prompt_embeds.dtype + ) + + image = self.video_processor.preprocess(image, height=height, width=width).to( + device, dtype=prompt_embeds.dtype + ) + + latent_channels = self.transformer.config.in_channels // 2 + latents, image_latents = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + kps_cond, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + timesteps_cpu = timesteps.cpu() + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents + latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + id_vit_hidden=id_vit_hidden, + id_cond=id_cond, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + ( + 1 + - math.cos( + math.pi + * ((num_inference_steps - timesteps_cpu[i].item()) / num_inference_steps) ** 5.0 + ) + ) + / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return ConsisIDPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/consisid/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/consisid/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..83a5be8d230ba998c8ccd6f450a9d679592a9894 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/consisid/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class ConsisIDPipelineOutput(BaseOutput): + r""" + Output class for ConsisID pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/consistency_models/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/consistency_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..162d91c010acf95aa2daf87c51ab1e0c68361fd5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/consistency_models/__init__.py @@ -0,0 +1,24 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + _LazyModule, +) + + +_import_structure = { + "pipeline_consistency_models": ["ConsistencyModelPipeline"], +} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_consistency_models import ConsistencyModelPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/consistency_models/pipeline_consistency_models.py new file mode 100644 index 0000000000000000000000000000000000000000..85e59adc39a433c13d2a27391143bf19527899c5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -0,0 +1,282 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable + +import torch + +from ...models import UNet2DModel +from ...schedulers import CMStochasticIterativeScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + + >>> from diffusers import ConsistencyModelPipeline + + >>> device = "cuda" + >>> # Load the cd_imagenet64_l2 checkpoint. + >>> model_id_or_path = "openai/diffusers-cd_imagenet64_l2" + >>> pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) + >>> pipe.to(device) + + >>> # Onestep Sampling + >>> image = pipe(num_inference_steps=1).images[0] + >>> image.save("cd_imagenet64_l2_onestep_sample.png") + + >>> # Onestep sampling, class-conditional image generation + >>> # ImageNet-64 class label 145 corresponds to king penguins + >>> image = pipe(num_inference_steps=1, class_labels=145).images[0] + >>> image.save("cd_imagenet64_l2_onestep_sample_penguin.png") + + >>> # Multistep sampling, class-conditional image generation + >>> # Timesteps can be explicitly specified; the particular timesteps below are from the original GitHub repo: + >>> # https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L77 + >>> image = pipe(num_inference_steps=None, timesteps=[22, 0], class_labels=145).images[0] + >>> image.save("cd_imagenet64_l2_multistep_sample_penguin.png") + ``` +""" + + +class ConsistencyModelPipeline(DiffusionPipeline): + r""" + Pipeline for unconditional or class-conditional image generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Currently only + compatible with [`CMStochasticIterativeScheduler`]. + """ + + model_cpu_offload_seq = "unet" + + def __init__(self, unet: UNet2DModel, scheduler: CMStochasticIterativeScheduler) -> None: + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + ) + + self.safety_checker = None + + def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Follows diffusers.VaeImageProcessor.postprocess + def postprocess_image(self, sample: torch.Tensor, output_type: str = "pil"): + if output_type not in ["pt", "np", "pil"]: + raise ValueError( + f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']" + ) + + # Equivalent to diffusers.VaeImageProcessor.denormalize + sample = (sample / 2 + 0.5).clamp(0, 1) + if output_type == "pt": + return sample + + # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy + sample = sample.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "np": + return sample + + # Output_type must be 'pil' + sample = self.numpy_to_pil(sample) + return sample + + def prepare_class_labels(self, batch_size, device, class_labels=None): + if self.unet.config.num_class_embeds is not None: + if isinstance(class_labels, list): + class_labels = torch.tensor(class_labels, dtype=torch.int) + elif isinstance(class_labels, int): + assert batch_size == 1, "Batch size must be 1 if classes is an int" + class_labels = torch.tensor([class_labels], dtype=torch.int) + elif class_labels is None: + # Randomly generate batch_size class labels + # TODO: should use generator here? int analogue of randn_tensor is not exposed in ...utils + class_labels = torch.randint(0, self.unet.config.num_class_embeds, size=(batch_size,)) + class_labels = class_labels.to(device) + else: + class_labels = None + return class_labels + + def check_inputs(self, num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps): + if num_inference_steps is None and timesteps is None: + raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.") + + if num_inference_steps is not None and timesteps is not None: + logger.warning( + f"Both `num_inference_steps`: {num_inference_steps} and `timesteps`: {timesteps} are supplied;" + " `timesteps` will be used over `num_inference_steps`." + ) + + if latents is not None: + expected_shape = (batch_size, 3, img_size, img_size) + if latents.shape != expected_shape: + raise ValueError(f"The shape of latents is {latents.shape} but is expected to be {expected_shape}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + batch_size: int = 1, + class_labels: torch.Tensor | list[int] | int | None = None, + num_inference_steps: int = 1, + timesteps: list[int] = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + ): + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + class_labels (`torch.Tensor` or `list[int]` or `int`, *optional*): + Optional class labels for conditioning class-conditional consistency models. Not used if the model is + not class-conditional. + num_inference_steps (`int`, *optional*, defaults to 1): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + # 0. Prepare call parameters + img_size = self.unet.config.sample_size + device = self._execution_device + + # 1. Check inputs + self.check_inputs(num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps) + + # 2. Prepare image latents + # Sample image latents x_0 ~ N(0, sigma_0^2 * I) + sample = self.prepare_latents( + batch_size=batch_size, + num_channels=self.unet.config.in_channels, + height=img_size, + width=img_size, + dtype=self.unet.dtype, + device=device, + generator=generator, + latents=latents, + ) + + # 3. Handle class_labels for class-conditional models + class_labels = self.prepare_class_labels(batch_size, device, class_labels=class_labels) + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps) + timesteps = self.scheduler.timesteps + + # 5. Denoising loop + # Multistep sampling: implements Algorithm 1 in the paper + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + scaled_sample = self.scheduler.scale_model_input(sample, t) + model_output = self.unet(scaled_sample, t, class_labels=class_labels, return_dict=False)[0] + + sample = self.scheduler.step(model_output, t, sample, generator=generator)[0] + + # call the callback, if provided + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, sample) + + if XLA_AVAILABLE: + xm.mark_step() + + # 6. Post-process image sample + image = self.postprocess_image(sample, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a49dccf235a3b84b2695bba6724ba913d3f60155 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__init__.py @@ -0,0 +1,86 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["multicontrolnet"] = ["MultiControlNetModel"] + _import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"] + _import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"] + _import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"] + _import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"] + _import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"] + _import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"] + _import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"] + _import_structure["pipeline_controlnet_union_inpaint_sd_xl"] = ["StableDiffusionXLControlNetUnionInpaintPipeline"] + _import_structure["pipeline_controlnet_union_sd_xl"] = ["StableDiffusionXLControlNetUnionPipeline"] + _import_structure["pipeline_controlnet_union_sd_xl_img2img"] = ["StableDiffusionXLControlNetUnionImg2ImgPipeline"] +try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_flax_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) +else: + _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .multicontrolnet import MultiControlNetModel + from .pipeline_controlnet import StableDiffusionControlNetPipeline + from .pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline + from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline + from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline + from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline + from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline + from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline + from .pipeline_controlnet_union_inpaint_sd_xl import StableDiffusionXLControlNetUnionInpaintPipeline + from .pipeline_controlnet_union_sd_xl import StableDiffusionXLControlNetUnionPipeline + from .pipeline_controlnet_union_sd_xl_img2img import StableDiffusionXLControlNetUnionImg2ImgPipeline + + try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_flax_and_transformers_objects import * # noqa F403 + else: + from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline + + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6948271f4e2f730dd34552848c9cf39f522df0cc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/multicontrolnet.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/multicontrolnet.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3eb7bdc94bc61759d92211b37aed61f5e498b744 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/multicontrolnet.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e388ca4cf01a3bb9ff116135ee0b72c53864277f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_blip_diffusion.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_blip_diffusion.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0302e9530476f08f9c00599d372e6a10c2629f7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_blip_diffusion.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_img2img.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_img2img.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c96ee2d167678fb7cb03f7426f1affc8fad47414 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_img2img.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_inpaint.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_inpaint.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..291ad7a3e4d021f33756a62805603bc9c35065c1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_inpaint.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_inpaint_sd_xl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_inpaint_sd_xl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..153da7523f4582785c60b1c3dd1312f4c8bfca64 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_inpaint_sd_xl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_sd_xl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_sd_xl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80d535944f04b4d77f40fb611e6cb8528a44ea41 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_sd_xl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_sd_xl_img2img.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_sd_xl_img2img.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e06d4e0d4a16c9ac0bfdc8caba49070ddb5a3c0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_sd_xl_img2img.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_union_inpaint_sd_xl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_union_inpaint_sd_xl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a660c6f73827cec1ca71fad22bc558f0f4eff5a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_union_inpaint_sd_xl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_union_sd_xl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_union_sd_xl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a11742c371077e3e9ec980273f098b766012aaf7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_union_sd_xl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_union_sd_xl_img2img.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_union_sd_xl_img2img.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13be8d3e313d92b390b5fbcd1c584b9a9ff6e02c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_controlnet_union_sd_xl_img2img.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_flax_controlnet.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_flax_controlnet.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3209eba07195227027b272a30d1fe4a7059c2509 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/__pycache__/pipeline_flax_controlnet.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/multicontrolnet.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/multicontrolnet.py new file mode 100644 index 0000000000000000000000000000000000000000..6526dd8c9a576076db072659c81b457b66af7a35 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/multicontrolnet.py @@ -0,0 +1,12 @@ +from ...models.controlnets.multicontrolnet import MultiControlNetModel +from ...utils import deprecate, logging + + +logger = logging.get_logger(__name__) + + +class MultiControlNetModel(MultiControlNetModel): + def __init__(self, *args, **kwargs): + deprecation_message = "Importing `MultiControlNetModel` from `diffusers.pipelines.controlnet.multicontrolnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel`, instead." + deprecate("diffusers.pipelines.controlnet.multicontrolnet.MultiControlNetModel", "0.34", deprecation_message) + super().__init__(*args, **kwargs) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..86fa135abff474070e958448c031a32f565eb6c7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -0,0 +1,1368 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" + ... ) + >>> image = np.array(image) + + >>> # get canny image + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # load control net and stable diffusion v1-5 + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> pipe = StableDiffusionControlNetPipeline.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> # speed up diffusion process with faster scheduler and memory optimization + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + >>> # remove following line if xformers is not installed + >>> pipe.enable_xformers_memory_efficient_attention() + + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> generator = torch.manual_seed(0) + >>> image = pipe( + ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionControlNetPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `list[ControlNetModel]`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "image"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetModel | list[ControlNetModel] | tuple[ControlNetModel] | MultiControlNetModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + transposed_image = [list(t) for t in zip(*image)] + if len(transposed_image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: if you pass`image` as a list of list, each sublist must have the same length as the number of controlnets, but the sublists in `image` got {len(transposed_image)} images and {len(self.controlnet.nets)} ControlNets." + ) + for image_ in transposed_image: + self.check_image(image_, prompt, prompt_embeds) + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + else: + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError( + "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. " + "The conditioning scale must be fixed across the batch." + ) + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + guess_mode: bool = False, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single + ControlNet, each will be paired with each prompt in the `prompt` list. This also applies to multiple + ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + # Nested lists as ControlNet condition + if isinstance(image[0], list): + # Transpose the nested image list + image = [list(t) for t in zip(*image)] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 7.2 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + image = callback_outputs.pop("image", image) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + empty_device_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..4ca92b906842f009e1e071e4c166729e5b08d684 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py @@ -0,0 +1,421 @@ +# Copyright 2025 Salesforce.com, inc. +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import PIL.Image +import torch +from transformers import CLIPTokenizer + +from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...schedulers import PNDMScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..blip_diffusion.blip_image_processing import BlipImageProcessor +from ..blip_diffusion.modeling_blip2 import Blip2QFormerModel +from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers.pipelines import BlipDiffusionControlNetPipeline + >>> from diffusers.utils import load_image + >>> from controlnet_aux import CannyDetector + >>> import torch + + >>> blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained( + ... "Salesforce/blipdiffusion-controlnet", torch_dtype=torch.float16 + ... ).to("cuda") + + >>> style_subject = "flower" + >>> tgt_subject = "teapot" + >>> text_prompt = "on a marble table" + + >>> cldm_cond_image = load_image( + ... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/kettle.jpg" + ... ).resize((512, 512)) + >>> canny = CannyDetector() + >>> cldm_cond_image = canny(cldm_cond_image, 30, 70, output_type="pil") + >>> style_image = load_image( + ... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/flower.jpg" + ... ) + >>> guidance_scale = 7.5 + >>> num_inference_steps = 50 + >>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate" + + + >>> output = blip_diffusion_pipe( + ... text_prompt, + ... style_image, + ... cldm_cond_image, + ... style_subject, + ... tgt_subject, + ... guidance_scale=guidance_scale, + ... num_inference_steps=num_inference_steps, + ... neg_prompt=negative_prompt, + ... height=512, + ... width=512, + ... ).images + >>> output[0].save("image.png") + ``` +""" + + +class BlipDiffusionControlNetPipeline(DeprecatedPipelineMixin, DiffusionPipeline): + """ + Pipeline for Canny Edge based Controlled subject-driven generation using Blip Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + tokenizer ([`CLIPTokenizer`]): + Tokenizer for the text encoder + text_encoder ([`ContextCLIPTextModel`]): + Text encoder to encode the text prompt + vae ([`AutoencoderKL`]): + VAE model to map the latents to the image + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + scheduler ([`PNDMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + qformer ([`Blip2QFormerModel`]): + QFormer model to get multi-modal embeddings from the text and image. + controlnet ([`ControlNetModel`]): + ControlNet model to get the conditioning image embedding. + image_processor ([`BlipImageProcessor`]): + Image Processor to preprocess and postprocess the image. + ctx_begin_pos (int, `optional`, defaults to 2): + Position of the context token in the text encoder. + """ + + _last_supported_version = "0.33.1" + model_cpu_offload_seq = "qformer->text_encoder->unet->vae" + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: ContextCLIPTextModel, + vae: AutoencoderKL, + unet: UNet2DConditionModel, + scheduler: PNDMScheduler, + qformer: Blip2QFormerModel, + controlnet: ControlNetModel, + image_processor: BlipImageProcessor, + ctx_begin_pos: int = 2, + mean: list[float] = None, + std: list[float] = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + unet=unet, + scheduler=scheduler, + qformer=qformer, + controlnet=controlnet, + image_processor=image_processor, + ) + self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std) + + def get_query_embeddings(self, input_image, src_subject): + return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False) + + # from the original Blip Diffusion code, specifies the target subject and augments the prompt by repeating it + def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20): + rv = [] + for prompt, tgt_subject in zip(prompts, tgt_subjects): + prompt = f"a {tgt_subject} {prompt.strip()}" + # a trick to amplify the prompt + rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps))) + + return rv + + # Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def encode_prompt(self, query_embeds, prompt, device=None): + device = device or self._execution_device + + # embeddings for prompt, with query_embeds as context + max_len = self.text_encoder.text_model.config.max_position_embeddings + max_len -= self.qformer.config.num_query_tokens + + tokenized_prompt = self.tokenizer( + prompt, + padding="max_length", + truncation=True, + max_length=max_len, + return_tensors="pt", + ).to(device) + + batch_size = query_embeds.shape[0] + ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size + + text_embeddings = self.text_encoder( + input_ids=tokenized_prompt.input_ids, + ctx_embeddings=query_embeds, + ctx_begin_pos=ctx_begin_pos, + )[0] + + return text_embeddings + + # Adapted from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + ): + image = self.image_processor.preprocess( + image, + size={"width": width, "height": height}, + do_rescale=True, + do_center_crop=False, + do_normalize=False, + return_tensors="pt", + )["pixel_values"].to(device) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance: + image = torch.cat([image] * 2) + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: list[str], + reference_image: PIL.Image.Image, + condtioning_image: PIL.Image.Image, + source_subject_category: list[str], + target_subject_category: list[str], + latents: torch.Tensor | None = None, + guidance_scale: float = 7.5, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + generator: torch.Generator | list[torch.Generator] | None = None, + neg_prompt: str | None = "", + prompt_strength: float = 1.0, + prompt_reps: int = 20, + output_type: str | None = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`list[str]`): + The prompt or prompts to guide the image generation. + reference_image (`PIL.Image.Image`): + The reference image to condition the generation on. + condtioning_image (`PIL.Image.Image`): + The conditioning canny edge image to condition the generation on. + source_subject_category (`list[str]`): + The source subject category. + target_subject_category (`list[str]`): + The target subject category. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by random sampling. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + height (`int`, *optional*, defaults to 512): + The height of the generated image. + width (`int`, *optional*, defaults to 512): + The width of the generated image. + seed (`int`, *optional*, defaults to 42): + The seed to use for random generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + neg_prompt (`str`, *optional*, defaults to ""): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_strength (`float`, *optional*, defaults to 1.0): + The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps + to amplify the prompt. + prompt_reps (`int`, *optional*, defaults to 20): + The number of times the prompt is repeated along with prompt_strength to amplify the prompt. + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + device = self._execution_device + + reference_image = self.image_processor.preprocess( + reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt" + )["pixel_values"] + reference_image = reference_image.to(device) + + if isinstance(prompt, str): + prompt = [prompt] + if isinstance(source_subject_category, str): + source_subject_category = [source_subject_category] + if isinstance(target_subject_category, str): + target_subject_category = [target_subject_category] + + batch_size = len(prompt) + + prompt = self._build_prompt( + prompts=prompt, + tgt_subjects=target_subject_category, + prompt_strength=prompt_strength, + prompt_reps=prompt_reps, + ) + query_embeds = self.get_query_embeddings(reference_image, source_subject_category) + text_embeddings = self.encode_prompt(query_embeds, prompt, device) + # 3. unconditional embedding + do_classifier_free_guidance = guidance_scale > 1.0 + if do_classifier_free_guidance: + max_length = self.text_encoder.text_model.config.max_position_embeddings + + uncond_input = self.tokenizer( + [neg_prompt] * batch_size, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + uncond_embeddings = self.text_encoder( + input_ids=uncond_input.input_ids.to(device), + ctx_embeddings=None, + )[0] + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1) + latents = self.prepare_latents( + batch_size=batch_size, + num_channels=self.unet.config.in_channels, + height=height // scale_down_factor, + width=width // scale_down_factor, + generator=generator, + latents=latents, + dtype=self.unet.dtype, + device=device, + ) + # set timesteps + extra_set_kwargs = {} + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + + cond_image = self.prepare_control_image( + image=condtioning_image, + width=width, + height=height, + batch_size=batch_size, + num_images_per_prompt=1, + device=device, + dtype=self.controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + do_classifier_free_guidance = guidance_scale > 1.0 + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + down_block_res_samples, mid_block_res_sample = self.controlnet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings, + controlnet_cond=cond_image, + return_dict=False, + ) + + noise_pred = self.unet( + latent_model_input, + timestep=t, + encoder_hidden_states=text_embeddings, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + )["sample"] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step( + noise_pred, + t, + latents, + )["prev_sample"] + + if XLA_AVAILABLE: + xm.mark_step() + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..f0cfabc66f25bb7f39326fe3c40f7d17986a596c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -0,0 +1,1336 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" + ... ) + >>> np_image = np.array(image) + + >>> # get canny image + >>> np_image = cv2.Canny(np_image, 100, 200) + >>> np_image = np_image[:, :, None] + >>> np_image = np.concatenate([np_image, np_image, np_image], axis=2) + >>> canny_image = Image.fromarray(np_image) + + >>> # load control net and stable diffusion v1-5 + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> # speed up diffusion process with faster scheduler and memory optimization + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> generator = torch.manual_seed(0) + >>> image = pipe( + ... "futuristic-looking woman", + ... num_inference_steps=20, + ... generator=generator, + ... image=image, + ... control_image=canny_image, + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def prepare_image(image): + if isinstance(image, torch.Tensor): + # Batch single image + if image.ndim == 3: + image = image.unsqueeze(0) + + image = image.to(dtype=torch.float32) + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + return image + + +class StableDiffusionControlNetImg2ImgPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + r""" + Pipeline for image-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `list[ControlNetModel]`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "control_image"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetModel | list[ControlNetModel] | tuple[ControlNetModel] | MultiControlNetModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + strength: float = 0.8, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 0.8, + guess_mode: bool = False, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The initial image to be used as the starting point for the image generation process. Can also accept + image latents as `image`, and if passing latents directly they are not encoded again. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + control_image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare image + image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + + # 5. Prepare controlnet_conditioning_image + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(control_image_) + + control_image = control_images + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 7.2 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + control_image = callback_outputs.pop("control_image", control_image) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + empty_device_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..d34278d0086b9f2d918006af1f2cb13335212273 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -0,0 +1,1525 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/ + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install transformers accelerate + >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png" + ... ) + >>> init_image = init_image.resize((512, 512)) + + >>> generator = torch.Generator(device="cpu").manual_seed(1) + + >>> mask_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" + ... ) + >>> mask_image = mask_image.resize((512, 512)) + + + >>> def make_canny_condition(image): + ... image = np.array(image) + ... image = cv2.Canny(image, 100, 200) + ... image = image[:, :, None] + ... image = np.concatenate([image, image, image], axis=2) + ... image = Image.fromarray(image) + ... return image + + + >>> control_image = make_canny_condition(init_image) + + >>> controlnet = ControlNetModel.from_pretrained( + ... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16 + ... ) + >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> image = pipe( + ... "a handsome man with ray-ban sunglasses", + ... num_inference_steps=20, + ... generator=generator, + ... eta=1.0, + ... image=init_image, + ... mask_image=mask_image, + ... control_image=control_image, + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class StableDiffusionControlNetInpaintPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + r""" + Pipeline for image inpainting using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + > [!TIP] > This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting > + ([stable-diffusion-v1-5/stable-diffusion-inpainting](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting)) + > as well as default text-to-image Stable Diffusion checkpoints > + ([stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)). + > Default text-to-image Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned + on > those, such as + [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `list[ControlNetModel]`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "control_image", + "mask", + "masked_image_latents", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetModel | list[ControlNetModel] | tuple[ControlNetModel] | MultiControlNetModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + image, + mask_image, + height, + width, + callback_steps, + output_type, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + ): + if height is not None and height % 8 != 0 or width is not None and width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords, + resize_mode, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 4: + image_latents = image + else: + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + padding_mask_crop: int | None = None, + strength: float = 1.0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 0.5, + guess_mode: bool = False, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, + `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, NumPy array or tensor representing an image batch to be used as the starting point. For both + NumPy array and PyTorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a + list or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a NumPy array or + a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, + `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, NumPy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a NumPy array or PyTorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for PyTorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for NumPy array, it would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, + W, 1)`, or `(H, W)`. + control_image (`torch.Tensor`, `PIL.Image.Image`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, + `list[list[torch.Tensor]]`, or `list[list[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 0.5): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + control_image, + mask_image, + height, + width, + callback_steps, + output_type, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + padding_mask_crop, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if padding_mask_crop is not None: + height, width = self.image_processor.get_default_height_width(image, height, width) + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + crops_coords=crops_coords, + resize_mode=resize_mode, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + crops_coords=crops_coords, + resize_mode=resize_mode, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(control_image_) + + control_image = control_images + else: + assert False + + # 4.1 Preprocess mask and image - resizes image and mask w.r.t height and width + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + mask = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + masked_image = init_image * (mask < 0.5) + _, _, height, width = init_image.shape + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 7.2 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if num_channels_unet == 4: + init_latents_proper = image_latents + if self.do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + control_image = callback_outputs.pop("control_image", control_image) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + empty_device_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if padding_mask_crop is not None: + image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..942bcb49083e2aaa2a78ea038e1d6169ee2bd04f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -0,0 +1,1865 @@ +# Copyright 2025 Harutatsu Akiyama, Jinbin Bai, and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_invisible_watermark_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + + +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install transformers accelerate + >>> from diffusers import StableDiffusionXLControlNetInpaintPipeline, ControlNetModel, DDIMScheduler + >>> from diffusers.utils import load_image + >>> from PIL import Image + >>> import numpy as np + >>> import torch + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png" + ... ) + >>> init_image = init_image.resize((1024, 1024)) + + >>> generator = torch.Generator(device="cpu").manual_seed(1) + + >>> mask_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" + ... ) + >>> mask_image = mask_image.resize((1024, 1024)) + + + >>> def make_canny_condition(image): + ... image = np.array(image) + ... image = cv2.Canny(image, 100, 200) + ... image = image[:, :, None] + ... image = np.concatenate([image, image, image], axis=2) + ... image = Image.fromarray(image) + ... return image + + + >>> control_image = make_canny_condition(init_image) + + >>> controlnet = ControlNetModel.from_pretrained( + ... "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16 + ... ) + >>> pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> image = pipe( + ... "a handsome man with ray-ban sunglasses", + ... num_inference_steps=20, + ... generator=generator, + ... eta=1.0, + ... image=init_image, + ... mask_image=mask_image, + ... control_image=control_image, + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionXLControlNetInpaintPipeline( + DiffusionPipeline, + StableDiffusionMixin, + StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, + IPAdapterMixin, + TextualInversionLoaderMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "add_neg_time_ids", + "mask", + "masked_image_latents", + "control_image", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetModel | list[ControlNetModel] | tuple[ControlNetModel] | MultiControlNetModel, + scheduler: KarrasDiffusionSchedulers, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool | None = None, + feature_extractor: CLIPImageProcessor | None = None, + image_encoder: CLIPVisionModelWithProjection | None = None, + ): + super().__init__() + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def check_inputs( + self, + prompt, + prompt_2, + image, + mask_image, + strength, + num_inference_steps, + callback_steps, + output_type, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if num_inference_steps is None: + raise ValueError("`num_inference_steps` cannot be None.") + elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords, + resize_mode, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + add_noise=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 4: + image_latents = image + else: + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None and add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + elif add_noise: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + else: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = image_latents.to(device) + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + dtype = image.dtype + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + image_latents = image_latents.to(dtype) + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + + masked_image_latents = None + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(self.scheduler.timesteps) - num_inference_steps + timesteps = self.scheduler.timesteps[t_start:] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,)) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + control_image: PipelineImageInput | list[PipelineImageInput] = None, + height: int | None = None, + width: int | None = None, + padding_mask_crop: int | None = None, + strength: float = 0.9999, + num_inference_steps: int = 50, + denoising_start: float | None = None, + denoising_end: float | None = None, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + guess_mode: bool = False, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + guidance_rescale: float = 0.0, + original_size: tuple[int, int] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 0.9999): + Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be + between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the + `strength`. The number of denoising steps depends on the amount of noise initially added. When + `strength` is 1, added noise will be maximum and the denoising process will run for the full number of + iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked + portion of the reference `image`. Note that in the case of `denoising_start` being declared as an + integer, the value of `strength` will be ignored. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_start (`float`, *optional*): + When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be + bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and + it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, + strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline + is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be + denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the + final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline + forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. `tuple. When returning a tuple, the first element is a list with the generated images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # # 0.0 Default height and width to unet + # height = height or self.unet.config.sample_size * self.vae_scale_factor + # width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 0.1 align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs + self.check_inputs( + prompt, + prompt_2, + control_image, + mask_image, + strength, + num_inference_steps, + callback_steps, + output_type, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + padding_mask_crop, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 3.1 Encode ip_adapter_image + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. set timesteps + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=denoising_start if denoising_value_valid(denoising_start) else None, + ) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + self._num_timesteps = len(timesteps) + + # 5. Preprocess mask and image - resizes image and mask w.r.t height and width + # 5.1 Prepare init image + if padding_mask_crop is not None: + height, width = self.image_processor.get_default_height_width(image, height, width) + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 5.2 Prepare control images + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + crops_coords=crops_coords, + resize_mode=resize_mode, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + crops_coords=crops_coords, + resize_mode=resize_mode, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(control_image_) + + control_image = control_images + else: + raise ValueError(f"{controlnet.__class__} is not supported.") + + # 5.3 Prepare mask + mask = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + masked_image = init_image * (mask < 0.5) + _, _, height, width = init_image.shape + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + add_noise = True if denoising_start is None else False + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + add_noise=add_noise, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + ) + + # 8. Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for stable-diffusion-v1-5/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + # 8.1 Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8.2 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps if isinstance(controlnet, MultiControlNetModel) else keeps[0]) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + height, width = latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 10. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + + # 11. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + if ( + denoising_end is not None + and denoising_start is not None + and denoising_value_valid(denoising_end) + and denoising_value_valid(denoising_start) + and denoising_start >= denoising_end + ): + raise ValueError( + f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: " + + f" {denoising_end} when using type float." + ) + elif denoising_end is not None and denoising_value_valid(denoising_end): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + # # Resize control_image to match the size of the input to the controlnet + # if control_image.shape[-2:] != control_model_input.shape[-2:]: + # control_image = F.interpolate(control_image, size=control_model_input.shape[-2:], mode="bilinear", align_corners=False) + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if num_channels_unet == 4: + init_latents_proper = image_latents + if self.do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + control_image = callback_outputs.pop("control_image", control_image) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + empty_device_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + return StableDiffusionXLPipelineOutput(images=latents) + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + if padding_mask_crop is not None: + image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..89e8d2e54123397c8f6c31e72b34d0de4056ddf7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -0,0 +1,1605 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from diffusers.utils.import_utils import is_invisible_watermark_available + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + + +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" + >>> negative_prompt = "low quality, bad quality, sketches" + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... ) + + >>> # initialize the models and pipeline + >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization + >>> controlnet = ControlNetModel.from_pretrained( + ... "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16 + ... ) + >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + >>> pipe = StableDiffusionXLControlNetPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> # get canny image + >>> image = np.array(image) + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # generate image + >>> image = pipe( + ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLControlNetPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]): + Second frozen text-encoder + ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + tokenizer_2 ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `list[ControlNetModel]`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings should always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to + watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no + watermarker is used. + """ + + # leave controlnet out on purpose because it iterates with unet + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + "image_encoder", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + "image", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetModel | list[ControlNetModel] | tuple[ControlNetModel] | MultiControlNetModel, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool | None = None, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + image, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + negative_pooled_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + denoising_end: float | None = None, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + guess_mode: bool = False, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + original_size: tuple[int, int] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` + and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, pooled text embeddings are generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt + weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input + argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned containing the output images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3.1 Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 3.2 Encode ip_adapter_image + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 7.2 Prepare added time ids & embeddings + if isinstance(image, list): + original_size = original_size or image[0].shape[-2:] + else: + original_size = original_size or image.shape[-2:] + target_size = target_size or (height, width) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if ( + torch.cuda.is_available() + and (is_unet_compiled and is_controlnet_compiled) + and is_torch_higher_equal_2_1 + ): + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + image = callback_outputs.pop("image", image) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..8c39856ec7cb102b36539d07d2e3a9e79206a147 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -0,0 +1,1665 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from diffusers.utils.import_utils import is_invisible_watermark_available + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + + +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # pip install accelerate transformers safetensors diffusers + + >>> import torch + >>> import numpy as np + >>> from PIL import Image + + >>> from transformers import DPTImageProcessor, DPTForDepthEstimation + >>> from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL + >>> from diffusers.utils import load_image + + >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda") + >>> feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas") + >>> controlnet = ControlNetModel.from_pretrained( + ... "diffusers/controlnet-depth-sdxl-1.0-small", + ... variant="fp16", + ... use_safetensors=True, + ... torch_dtype=torch.float16, + ... ) + >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + >>> pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", + ... controlnet=controlnet, + ... vae=vae, + ... variant="fp16", + ... use_safetensors=True, + ... torch_dtype=torch.float16, + ... ) + >>> pipe.enable_model_cpu_offload() + + + >>> def get_depth_map(image): + ... image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda") + ... with torch.no_grad(), torch.autocast("cuda"): + ... depth_map = depth_estimator(image).predicted_depth + + ... depth_map = torch.nn.functional.interpolate( + ... depth_map.unsqueeze(1), + ... size=(1024, 1024), + ... mode="bicubic", + ... align_corners=False, + ... ) + ... depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True) + ... depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True) + ... depth_map = (depth_map - depth_min) / (depth_max - depth_min) + ... image = torch.cat([depth_map] * 3, dim=1) + ... image = image.permute(0, 2, 3, 1).cpu().numpy()[0] + ... image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8)) + ... return image + + + >>> prompt = "A robot, 4k photo" + >>> image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ).resize((1024, 1024)) + >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization + >>> depth_image = get_depth_map(image) + + >>> images = pipe( + ... prompt, + ... image=image, + ... control_image=depth_image, + ... strength=0.99, + ... num_inference_steps=50, + ... controlnet_conditioning_scale=controlnet_conditioning_scale, + ... ).images + >>> images[0].save(f"robot_cat.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class StableDiffusionXLControlNetImg2ImgPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, + IPAdapterMixin, +): + r""" + Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `list[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets + as a list, the outputs from each ControlNet are added together to create one combined additional + conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): + Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the + config of `stabilityai/stable-diffusion-xl-refiner-1-0`. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + "image_encoder", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "add_neg_time_ids", + "control_image", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetModel | list[ControlNetModel] | tuple[ControlNetModel] | MultiControlNetModel, + scheduler: KarrasDiffusionSchedulers, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool | None = None, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + image, + strength, + num_inference_steps, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if num_inference_steps is None: + raise ValueError("`num_inference_steps` cannot be None.") + elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents + def prepare_latents( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + + # Offload text encoder if `enable_model_cpu_offload` was enabled + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.text_encoder_2.to("cpu") + empty_device_cache() + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + strength: float = 0.8, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 0.8, + guess_mode: bool = False, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + original_size: tuple[int, int] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The initial image will be used as the starting point for the image generation process. Can also accept + image latents as `image`, if passing latents directly, it will not be encoded again. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also + be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height + and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in + init, images must be passed as a list such that each element of the list can be correctly batched for + input to a single controlnet. + height (`int`, *optional*, defaults to the size of control_image): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to the size of control_image): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the controlnet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the controlnet stops applying. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple` + containing the output images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + control_image, + strength, + num_inference_steps, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3.1. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 3.2 Encode ip_adapter_image + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare image and controlnet_conditioning_image + image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = control_image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(control_image_) + + control_image = control_images + height, width = control_image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + True, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 7.2 Prepare added time ids & embeddings + if isinstance(control_image, list): + original_size = original_size or control_image[0].shape[-2:] + else: + original_size = original_size or control_image.shape[-2:] + target_size = target_size or (height, width) + + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + add_text_embeds = pooled_prompt_embeds + + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) + control_image = callback_outputs.pop("control_image", control_image) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + empty_device_cache() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..4b7ca284d63679bd307fe3a8dd2d7c01a1fe9dce --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py @@ -0,0 +1,1896 @@ +# Copyright 2025 Harutatsu Akiyama, Jinbin Bai, and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import ( + AutoencoderKL, + ControlNetUnionModel, + ImageProjection, + MultiControlNetUnionModel, + UNet2DConditionModel, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_invisible_watermark_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + + +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + from diffusers import StableDiffusionXLControlNetUnionInpaintPipeline, ControlNetUnionModel, AutoencoderKL + from diffusers.utils import load_image + import torch + import numpy as np + from PIL import Image + + prompt = "A cat" + # download an image + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint/overture-creations-5sI6fQgYIuo.png" + ).resize((1024, 1024)) + mask = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint/overture-creations-5sI6fQgYIuo_mask.png" + ).resize((1024, 1024)) + # initialize the models and pipeline + controlnet = ControlNetUnionModel.from_pretrained( + "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16 + ) + vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + pipe = StableDiffusionXLControlNetUnionInpaintPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + controlnet=controlnet, + vae=vae, + torch_dtype=torch.float16, + variant="fp16", + ) + pipe.enable_model_cpu_offload() + controlnet_img = image.copy() + controlnet_img_np = np.array(controlnet_img) + mask_np = np.array(mask) + controlnet_img_np[mask_np > 0] = 0 + controlnet_img = Image.fromarray(controlnet_img_np) + # generate image + image = pipe(prompt, image=image, mask_image=mask, control_image=[controlnet_img], control_mode=[7]).images[0] + image.save("inpaint.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionXLControlNetUnionInpaintPipeline( + DiffusionPipeline, + StableDiffusionMixin, + StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, + IPAdapterMixin, + TextualInversionLoaderMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "add_text_embeds", + "add_time_ids", + "mask", + "masked_image_latents", + "control_image", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetUnionModel + | list[ControlNetUnionModel] + | tuple[ControlNetUnionModel] + | MultiControlNetUnionModel, + scheduler: KarrasDiffusionSchedulers, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool | None = None, + feature_extractor: CLIPImageProcessor | None = None, + image_encoder: CLIPVisionModelWithProjection | None = None, + ): + super().__init__() + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetUnionModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def check_inputs( + self, + prompt, + prompt_2, + image, + mask_image, + strength, + num_inference_steps, + callback_steps, + output_type, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + control_mode=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if num_inference_steps is None: + raise ValueError("`num_inference_steps` cannot be None.") + elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetUnionModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + if isinstance(controlnet, ControlNetUnionModel): + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + elif isinstance(controlnet, MultiControlNetUnionModel): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + elif not all(isinstance(i, list) for i in image): + raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for images_ in image: + for image_ in images_: + self.check_image(image_, prompt, prompt_embeds) + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(controlnet, MultiControlNetUnionModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Check `control_mode` + if isinstance(controlnet, ControlNetUnionModel): + if max(control_mode) >= controlnet.config.num_control_type: + raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.") + elif isinstance(controlnet, MultiControlNetUnionModel): + for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets): + if max(_control_mode) >= _controlnet.config.num_control_type: + raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.") + + # Equal number of `image` and `control_mode` elements + if isinstance(controlnet, ControlNetUnionModel): + if len(image) != len(control_mode): + raise ValueError("Expected len(control_image) == len(control_mode)") + elif isinstance(controlnet, MultiControlNetUnionModel): + if not all(isinstance(i, list) for i in control_mode): + raise ValueError( + "For multiple controlnets: elements of control_mode must be lists representing conditioning mode." + ) + + elif sum(len(x) for x in image) != sum(len(x) for x in control_mode): + raise ValueError("Expected len(control_image) == len(control_mode)") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_inpaint_sd_xl.StableDiffusionXLControlNetInpaintPipeline.prepare_control_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords, + resize_mode, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_inpaint_sd_xl.StableDiffusionXLControlNetInpaintPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + add_noise=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 4: + image_latents = image + else: + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None and add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + elif add_noise: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + else: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = image_latents.to(device) + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_inpaint_sd_xl.StableDiffusionXLControlNetInpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + dtype = image.dtype + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + image_latents = image_latents.to(dtype) + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_inpaint_sd_xl.StableDiffusionXLControlNetInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + + masked_image_latents = None + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(self.scheduler.timesteps) - num_inference_steps + timesteps = self.scheduler.timesteps[t_start:] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,)) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + control_image: PipelineImageInput | list[PipelineImageInput] = None, + height: int | None = None, + width: int | None = None, + padding_mask_crop: int | None = None, + strength: float = 0.9999, + num_inference_steps: int = 50, + denoising_start: float | None = None, + denoising_end: float | None = None, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + guess_mode: bool = False, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + control_mode: int | list[int] | list[list[int]] | None = None, + guidance_rescale: float = 0.0, + original_size: tuple[int, int] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + control_image (`PipelineImageInput` or `list[PipelineImageInput]`, *optional*): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 0.9999): + Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be + between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the + `strength`. The number of denoising steps depends on the amount of noise initially added. When + `strength` is 1, added noise will be maximum and the denoising process will run for the full number of + iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked + portion of the reference `image`. Note that in the case of `denoising_start` being declared as an + integer, the value of `strength` will be ignored. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_start (`float`, *optional*): + When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be + bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and + it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, + strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline + is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be + denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the + final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline + forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + control_mode (`int` or `list[int]` or `list[list[int]], *optional*): + The control condition types for the ControlNet. See the ControlNet's model card forinformation on the + available control modes. If multiple ControlNets are specified in `init`, control_mode should be a list + where each ControlNet should have its corresponding control mode list. Should reflect the order of + conditions in control_image. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. `tuple. When returning a tuple, the first element is a list with the generated images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + if not isinstance(control_image, list): + control_image = [control_image] + else: + control_image = control_image.copy() + + if not isinstance(control_mode, list): + control_mode = [control_mode] + + if isinstance(controlnet, MultiControlNetUnionModel): + control_image = [[item] for item in control_image] + control_mode = [[item] for item in control_mode] + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode) + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + if isinstance(controlnet_conditioning_scale, float): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode) + controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult + + # 1. Check inputs + self.check_inputs( + prompt, + prompt_2, + control_image, + mask_image, + strength, + num_inference_steps, + callback_steps, + output_type, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + control_mode, + callback_on_step_end_tensor_inputs, + padding_mask_crop, + ) + + if isinstance(controlnet, ControlNetUnionModel): + control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1) + elif isinstance(controlnet, MultiControlNetUnionModel): + control_type = [ + torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1) + for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets) + ] + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 3.1 Encode ip_adapter_image + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. set timesteps + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=denoising_start if denoising_value_valid(denoising_start) else None, + ) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + self._num_timesteps = len(timesteps) + + # 5. Preprocess mask and image - resizes image and mask w.r.t height and width + # 5.1 Prepare init image + if padding_mask_crop is not None: + height, width = self.image_processor.get_default_height_width(image, height, width) + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 5.2 Prepare control images + if isinstance(controlnet, ControlNetUnionModel): + control_images = [] + + for image_ in control_image: + image_ = self.prepare_control_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + crops_coords=crops_coords, + resize_mode=resize_mode, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(image_) + + control_image = control_images + height, width = control_image[0].shape[-2:] + + elif isinstance(controlnet, MultiControlNetUnionModel): + control_images = [] + + for control_image_ in control_image: + images = [] + + for image_ in control_image_: + image_ = self.prepare_control_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + crops_coords=crops_coords, + resize_mode=resize_mode, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + control_images.append(images) + + control_image = control_images + height, width = control_image[0][0].shape[-2:] + + # 5.3 Prepare mask + mask = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + masked_image = init_image * (mask < 0.5) + _, _, height, width = init_image.shape + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + add_noise = True if denoising_start is None else False + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + add_noise=add_noise, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, _ = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + ) + + # 8. Check that sizes of mask, masked image and latents match + if num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + # 8.1 Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8.2 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + height, width = latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + for _image in control_image: + if isinstance(_image, torch.Tensor): + original_size = original_size or _image.shape[-2:] + + # 10. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + + # 11. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + if ( + denoising_end is not None + and denoising_start is not None + and denoising_value_valid(denoising_end) + and denoising_value_valid(denoising_start) + and denoising_start >= denoising_end + ): + raise ValueError( + f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: " + + f" {denoising_end} when using type float." + ) + elif denoising_end is not None and denoising_value_valid(denoising_end): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + control_type_repeat_factor = ( + batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1) + ) + + if isinstance(controlnet, ControlNetUnionModel): + control_type = ( + control_type.reshape(1, -1) + .to(self._execution_device, dtype=prompt_embeds.dtype) + .repeat(control_type_repeat_factor, 1) + ) + elif isinstance(controlnet, MultiControlNetUnionModel): + control_type = [ + _control_type.reshape(1, -1) + .to(self._execution_device, dtype=prompt_embeds.dtype) + .repeat(control_type_repeat_factor, 1) + for _control_type in control_type + ] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = { + "text_embeds": add_text_embeds, + "time_ids": add_time_ids, + } + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + # # Resize control_image to match the size of the input to the controlnet + # if control_image.shape[-2:] != control_model_input.shape[-2:]: + # control_image = F.interpolate(control_image, size=control_model_input.shape[-2:], mode="bilinear", align_corners=False) + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + control_type=control_type, + control_type_idx=control_mode, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + init_latents_proper = image_latents + if self.do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + control_image = callback_outputs.pop("control_image", control_image) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + empty_device_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + return StableDiffusionXLPipelineOutput(images=latents) + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + if padding_mask_crop is not None: + image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..87057c2392df6b782280873cfca35b5bd6f69b9a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py @@ -0,0 +1,1611 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from diffusers.utils.import_utils import is_invisible_watermark_available + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import ( + AutoencoderKL, + ControlNetUnionModel, + ImageProjection, + MultiControlNetUnionModel, + UNet2DConditionModel, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + + +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install controlnet_aux + >>> from controlnet_aux import LineartAnimeDetector + >>> from diffusers import StableDiffusionXLControlNetUnionPipeline, ControlNetUnionModel, AutoencoderKL + >>> from diffusers.utils import load_image + >>> import torch + + >>> prompt = "A cat" + >>> # download an image + >>> image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png" + ... ).resize((1024, 1024)) + >>> # initialize the models and pipeline + >>> controlnet = ControlNetUnionModel.from_pretrained( + ... "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16 + ... ) + >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + >>> pipe = StableDiffusionXLControlNetUnionPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", + ... controlnet=controlnet, + ... vae=vae, + ... torch_dtype=torch.float16, + ... variant="fp16", + ... ) + >>> pipe.enable_model_cpu_offload() + >>> # prepare image + >>> processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators") + >>> controlnet_img = processor(image, output_type="pil") + >>> # generate image + >>> image = pipe(prompt, control_image=[controlnet_img], control_mode=[3], height=1024, width=1024).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLControlNetUnionPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]): + Second frozen text-encoder + ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + tokenizer_2 ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetUnionModel`]`): + Provides additional conditioning to the `unet` during the denoising process. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings should always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to + watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no + watermarker is used. + """ + + # leave controlnet out on purpose because it iterates with unet + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + "image_encoder", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "add_text_embeds", + "add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetUnionModel + | list[ControlNetUnionModel] + | tuple[ControlNetUnionModel] + | MultiControlNetUnionModel, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool | None = None, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetUnionModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def check_inputs( + self, + prompt, + prompt_2, + image: PipelineImageInput, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + negative_pooled_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + control_mode=None, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetUnionModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + if isinstance(controlnet, ControlNetUnionModel): + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + elif isinstance(controlnet, MultiControlNetUnionModel): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + elif not all(isinstance(i, list) for i in image): + raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for images_ in image: + for image_ in images_: + self.check_image(image_, prompt, prompt_embeds) + + # Check `controlnet_conditioning_scale` + if isinstance(controlnet, MultiControlNetUnionModel): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings is not supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(controlnet, MultiControlNetUnionModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control_guidance_start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control_guidance_start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control_guidance_end: {end} can't be larger than 1.0.") + + # Check `control_mode` + if isinstance(controlnet, ControlNetUnionModel): + if max(control_mode) >= controlnet.config.num_control_type: + raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.") + elif isinstance(controlnet, MultiControlNetUnionModel): + for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets): + if max(_control_mode) >= _controlnet.config.num_control_type: + raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.") + + # Equal number of `image` and `control_mode` elements + if isinstance(controlnet, ControlNetUnionModel): + if len(image) != len(control_mode): + raise ValueError("Expected len(control_image) == len(control_mode)") + elif isinstance(controlnet, MultiControlNetUnionModel): + if not all(isinstance(i, list) for i in control_mode): + raise ValueError( + "For multiple controlnets: elements of control_mode must be lists representing conditioning mode." + ) + + elif sum(len(x) for x in image) != sum(len(x) for x in control_mode): + raise ValueError("Expected len(control_image) == len(control_mode)") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + control_image: PipelineImageInput | list[PipelineImageInput] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + denoising_end: float | None = None, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + guess_mode: bool = False, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + control_mode: int | list[int] | list[list[int]] | None = None, + original_size: tuple[int, int] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. + control_image (`PipelineImageInput` or `list[PipelineImageInput]`, *optional*): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` + and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, pooled text embeddings are generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt + weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input + argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + control_mode (`int` or `list[int]` or `list[list[int]], *optional*): + The control condition types for the ControlNet. See the ControlNet's model card forinformation on the + available control modes. If multiple ControlNets are specified in `init`, control_mode should be a list + where each ControlNet should have its corresponding control mode list. Should reflect the order of + conditions in control_image. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned containing the output images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + if not isinstance(control_image, list): + control_image = [control_image] + else: + control_image = control_image.copy() + + if not isinstance(control_mode, list): + control_mode = [control_mode] + + if isinstance(controlnet, MultiControlNetUnionModel): + control_image = [[item] for item in control_image] + control_mode = [[item] for item in control_mode] + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode) + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + if isinstance(controlnet_conditioning_scale, float): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode) + controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult + + # 1. Check inputs + self.check_inputs( + prompt, + prompt_2, + control_image, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + control_mode, + callback_on_step_end_tensor_inputs, + ) + + if isinstance(controlnet, ControlNetUnionModel): + control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1) + elif isinstance(controlnet, MultiControlNetUnionModel): + control_type = [ + torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1) + for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets) + ] + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetUnionModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3.1 Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 3.2 Encode ip_adapter_image + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetUnionModel): + control_images = [] + + for image_ in control_image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(image_) + + control_image = control_images + height, width = control_image[0].shape[-2:] + + elif isinstance(controlnet, MultiControlNetUnionModel): + control_images = [] + + for control_image_ in control_image: + images = [] + + for image_ in control_image_: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + control_images.append(images) + + control_image = control_images + height, width = control_image[0][0].shape[-2:] + + # 5. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps) + + # 7.2 Prepare added time ids & embeddings + original_size = original_size or (height, width) + target_size = target_size or (height, width) + for _image in control_image: + if isinstance(_image, torch.Tensor): + original_size = original_size or _image.shape[-2:] + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + + control_type_repeat_factor = ( + batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1) + ) + + if isinstance(controlnet, ControlNetUnionModel): + control_type = ( + control_type.reshape(1, -1) + .to(self._execution_device, dtype=prompt_embeds.dtype) + .repeat(control_type_repeat_factor, 1) + ) + elif isinstance(controlnet, MultiControlNetUnionModel): + control_type = [ + _control_type.reshape(1, -1) + .to(self._execution_device, dtype=prompt_embeds.dtype) + .repeat(control_type_repeat_factor, 1) + for _control_type in control_type + ] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = { + "text_embeds": add_text_embeds, + "time_ids": add_time_ids, + } + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + control_type=control_type, + control_type_idx=control_mode, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..2fb87e57ea8a6951cac0b7901389ea2e7d1412c1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py @@ -0,0 +1,1711 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from diffusers.utils.import_utils import is_invisible_watermark_available + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import ( + AutoencoderKL, + ControlNetUnionModel, + ImageProjection, + MultiControlNetUnionModel, + UNet2DConditionModel, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + + +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + # !pip install controlnet_aux + from diffusers import ( + StableDiffusionXLControlNetUnionImg2ImgPipeline, + ControlNetUnionModel, + AutoencoderKL, + ) + from diffusers.utils import load_image + import torch + from PIL import Image + import numpy as np + + prompt = "A cat" + # download an image + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png" + ) + # initialize the models and pipeline + controlnet = ControlNetUnionModel.from_pretrained( + "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16 + ) + vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + pipe = StableDiffusionXLControlNetUnionImg2ImgPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + controlnet=controlnet, + vae=vae, + torch_dtype=torch.float16, + variant="fp16", + ).to("cuda") + # `enable_model_cpu_offload` is not recommended due to multiple generations + height = image.height + width = image.width + ratio = np.sqrt(1024.0 * 1024.0 / (width * height)) + # 3 * 3 upscale correspond to 16 * 3 multiply, 2 * 2 correspond to 16 * 2 multiply and so on. + scale_image_factor = 3 + base_factor = 16 + factor = scale_image_factor * base_factor + W, H = int(width * ratio) // factor * factor, int(height * ratio) // factor * factor + image = image.resize((W, H)) + target_width = W // scale_image_factor + target_height = H // scale_image_factor + images = [] + crops_coords_list = [ + (0, 0), + (0, width // 2), + (height // 2, 0), + (width // 2, height // 2), + 0, + 0, + 0, + 0, + 0, + ] + for i in range(scale_image_factor): + for j in range(scale_image_factor): + left = j * target_width + top = i * target_height + right = left + target_width + bottom = top + target_height + cropped_image = image.crop((left, top, right, bottom)) + cropped_image = cropped_image.resize((W, H)) + images.append(cropped_image) + # set ControlNetUnion input + result_images = [] + for sub_img, crops_coords in zip(images, crops_coords_list): + new_width, new_height = W, H + out = pipe( + prompt=[prompt] * 1, + image=sub_img, + control_image=[sub_img], + control_mode=[6], + width=new_width, + height=new_height, + num_inference_steps=30, + crops_coords_top_left=(W, H), + target_size=(W, H), + original_size=(W * 2, H * 2), + ) + result_images.append(out.images[0]) + new_im = Image.new("RGB", (new_width * scale_image_factor, new_height * scale_image_factor)) + new_im.paste(result_images[0], (0, 0)) + new_im.paste(result_images[1], (new_width, 0)) + new_im.paste(result_images[2], (new_width * 2, 0)) + new_im.paste(result_images[3], (0, new_height)) + new_im.paste(result_images[4], (new_width, new_height)) + new_im.paste(result_images[5], (new_width * 2, new_height)) + new_im.paste(result_images[6], (0, new_height * 2)) + new_im.paste(result_images[7], (new_width, new_height * 2)) + new_im.paste(result_images[8], (new_width * 2, new_height * 2)) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class StableDiffusionXLControlNetUnionImg2ImgPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, + IPAdapterMixin, +): + r""" + Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetUnionModel`]): + Provides additional conditioning to the unet during the denoising process. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): + Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the + config of `stabilityai/stable-diffusion-xl-refiner-1-0`. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + "image_encoder", + ] + _callback_tensor_inputs = ["latents", "prompt_embeds", "add_text_embeds", "add_time_ids", "control_image"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetUnionModel + | list[ControlNetUnionModel] + | tuple[ControlNetUnionModel] + | MultiControlNetUnionModel, + scheduler: KarrasDiffusionSchedulers, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool | None = None, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetUnionModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + image, + strength, + num_inference_steps, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + control_mode=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if num_inference_steps is None: + raise ValueError("`num_inference_steps` cannot be None.") + elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetUnionModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + if isinstance(controlnet, ControlNetUnionModel): + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + elif isinstance(controlnet, MultiControlNetUnionModel): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + elif not all(isinstance(i, list) for i in image): + raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for images_ in image: + for image_ in images_: + self.check_image(image_, prompt, prompt_embeds) + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if isinstance(controlnet, MultiControlNetUnionModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Check `control_mode` + if isinstance(controlnet, ControlNetUnionModel): + if max(control_mode) >= controlnet.config.num_control_type: + raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.") + elif isinstance(controlnet, MultiControlNetUnionModel): + for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets): + if max(_control_mode) >= _controlnet.config.num_control_type: + raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents + def prepare_latents( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + + # Offload text encoder if `enable_model_cpu_offload` was enabled + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.text_encoder_2.to("cpu") + empty_device_cache() + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + image: PipelineImageInput = None, + control_image: PipelineImageInput | list[PipelineImageInput] = None, + height: int | None = None, + width: int | None = None, + strength: float = 0.8, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 0.8, + guess_mode: bool = False, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + control_mode: int | list[int] | list[list[int]] | None = None, + original_size: tuple[int, int] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The initial image will be used as the starting point for the image generation process. Can also accept + image latents as `image`, if passing latents directly, it will not be encoded again. + control_image (`PipelineImageInput` or `list[PipelineImageInput]`, *optional*): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to the size of control_image): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to the size of control_image): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + control_mode (`int` or `list[int]` or `list[list[int]], *optional*): + The control condition types for the ControlNet. See the ControlNet's model card forinformation on the + available control modes. If multiple ControlNets are specified in `init`, control_mode should be a list + where each ControlNet should have its corresponding control mode list. Should reflect the order of + conditions in control_image + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple` + containing the output images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + if not isinstance(control_image, list): + control_image = [control_image] + else: + control_image = control_image.copy() + + if not isinstance(control_mode, list): + control_mode = [control_mode] + + if isinstance(controlnet, MultiControlNetUnionModel): + control_image = [[item] for item in control_image] + control_mode = [[item] for item in control_mode] + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode) + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + if isinstance(controlnet_conditioning_scale, float): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode) + controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult + + # 1. Check inputs + self.check_inputs( + prompt, + prompt_2, + control_image, + strength, + num_inference_steps, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + control_mode, + callback_on_step_end_tensor_inputs, + ) + + if isinstance(controlnet, ControlNetUnionModel): + control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1) + elif isinstance(controlnet, MultiControlNetUnionModel): + control_type = [ + torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1) + for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets) + ] + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetUnionModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3.1. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 3.2 Encode ip_adapter_image + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4.1 Prepare image + image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + + # 4.2 Prepare control images + if isinstance(controlnet, ControlNetUnionModel): + control_images = [] + + for image_ in control_image: + image_ = self.prepare_control_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(image_) + + control_image = control_images + height, width = control_image[0].shape[-2:] + + elif isinstance(controlnet, MultiControlNetUnionModel): + control_images = [] + + for control_image_ in control_image: + images = [] + + for image_ in control_image_: + image_ = self.prepare_control_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + control_images.append(images) + + control_image = control_images + height, width = control_image[0][0].shape[-2:] + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + True, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps) + + # 7.2 Prepare added time ids & embeddings + original_size = original_size or (height, width) + target_size = target_size or (height, width) + for _image in control_image: + if isinstance(_image, torch.Tensor): + original_size = original_size or _image.shape[-2:] + + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + add_text_embeds = pooled_prompt_embeds + + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + + control_type_repeat_factor = ( + batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1) + ) + + if isinstance(controlnet, ControlNetUnionModel): + control_type = ( + control_type.reshape(1, -1) + .to(self._execution_device, dtype=prompt_embeds.dtype) + .repeat(control_type_repeat_factor, 1) + ) + elif isinstance(controlnet, MultiControlNetUnionModel): + control_type = [ + _control_type.reshape(1, -1) + .to(self._execution_device, dtype=prompt_embeds.dtype) + .repeat(control_type_repeat_factor, 1) + for _control_type in control_type + ] + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = { + "text_embeds": add_text_embeds, + "time_ids": add_time_ids, + } + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + control_type=control_type, + control_type_idx=control_mode, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + control_image = callback_outputs.pop("control_image", control_image) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + empty_device_cache() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a26b9068afd1b0dc9cd2628df187fa4fb76104f5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py @@ -0,0 +1,528 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from functools import partial + +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict +from flax.jax_utils import unreplicate +from flax.training.common_utils import shard +from PIL import Image +from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel + +from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel +from ...schedulers import ( + FlaxDDIMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, +) +from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring +from ..pipeline_flax_utils import FlaxDiffusionPipeline +from ..stable_diffusion import FlaxStableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker_flax import FlaxStableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Set to True to use python for loop instead of jax.fori_loop for easier debugging +DEBUG = False + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import jax + >>> import numpy as np + >>> import jax.numpy as jnp + >>> from flax.jax_utils import replicate + >>> from flax.training.common_utils import shard + >>> from diffusers.utils import load_image, make_image_grid + >>> from PIL import Image + >>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel + + + >>> def create_key(seed=0): + ... return jax.random.PRNGKey(seed) + + + >>> rng = create_key(0) + + >>> # get canny image + >>> canny_image = load_image( + ... "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg" + ... ) + + >>> prompts = "best quality, extremely detailed" + >>> negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality" + + >>> # load control net and stable diffusion v1-5 + >>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( + ... "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32 + ... ) + >>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-v1-5", + ... controlnet=controlnet, + ... revision="flax", + ... dtype=jnp.float32, + ... ) + >>> params["controlnet"] = controlnet_params + + >>> num_samples = jax.device_count() + >>> rng = jax.random.split(rng, jax.device_count()) + + >>> prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples) + >>> negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples) + >>> processed_image = pipe.prepare_image_inputs([canny_image] * num_samples) + + >>> p_params = replicate(params) + >>> prompt_ids = shard(prompt_ids) + >>> negative_prompt_ids = shard(negative_prompt_ids) + >>> processed_image = shard(processed_image) + + >>> output = pipe( + ... prompt_ids=prompt_ids, + ... image=processed_image, + ... params=p_params, + ... prng_seed=rng, + ... num_inference_steps=50, + ... neg_prompt_ids=negative_prompt_ids, + ... jit=True, + ... ).images + + >>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:]))) + >>> output_images = make_image_grid(output_images, num_samples // 4, 4) + >>> output_images.save("generated_image.png") + ``` +""" + + +class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline): + r""" + Flax-based pipeline for text-to-image generation using Stable Diffusion with ControlNet Guidance. + + This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`FlaxAutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.FlaxCLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`FlaxUNet2DConditionModel`]): + A `FlaxUNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`FlaxControlNetModel`]: + Provides additional conditioning to the `unet` during the denoising process. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or + [`FlaxDPMSolverMultistepScheduler`]. + safety_checker ([`FlaxStableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + def __init__( + self, + vae: FlaxAutoencoderKL, + text_encoder: FlaxCLIPTextModel, + tokenizer: CLIPTokenizer, + unet: FlaxUNet2DConditionModel, + controlnet: FlaxControlNetModel, + scheduler: FlaxDDIMScheduler | FlaxPNDMScheduler | FlaxLMSDiscreteScheduler | FlaxDPMSolverMultistepScheduler, + safety_checker: FlaxStableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + dtype: jnp.dtype = jnp.float32, + ): + super().__init__() + self.dtype = dtype + + if safety_checker is None: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + + def prepare_text_inputs(self, prompt: str | list[str]): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + + return text_input.input_ids + + def prepare_image_inputs(self, image: Image.Image | list[Image.Image]): + if not isinstance(image, (Image.Image, list)): + raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") + + if isinstance(image, Image.Image): + image = [image] + + processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image]) + + return processed_images + + def _get_has_nsfw_concepts(self, features, params): + has_nsfw_concepts = self.safety_checker(features, params) + return has_nsfw_concepts + + def _run_safety_checker(self, images, safety_model_params, jit=False): + # safety_model_params should already be replicated when jit is True + pil_images = [Image.fromarray(image) for image in images] + features = self.feature_extractor(pil_images, return_tensors="np").pixel_values + + if jit: + features = shard(features) + has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) + has_nsfw_concepts = unshard(has_nsfw_concepts) + safety_model_params = unreplicate(safety_model_params) + else: + has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) + + images_was_copied = False + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if not images_was_copied: + images_was_copied = True + images = images.copy() + + images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image + + if any(has_nsfw_concepts): + warnings.warn( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead. Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts + + def _generate( + self, + prompt_ids: jnp.ndarray, + image: jnp.ndarray, + params: dict | FrozenDict, + prng_seed: jax.Array, + num_inference_steps: int, + guidance_scale: float, + latents: jnp.ndarray | None = None, + neg_prompt_ids: jnp.ndarray | None = None, + controlnet_conditioning_scale: float = 1.0, + ): + height, width = image.shape[-2:] + if height % 64 != 0 or width % 64 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.") + + # get prompt text embeddings + prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] + + # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` + # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` + batch_size = prompt_ids.shape[0] + + max_length = prompt_ids.shape[-1] + + if neg_prompt_ids is None: + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ).input_ids + else: + uncond_input = neg_prompt_ids + negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] + context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) + + image = jnp.concatenate([image] * 2) + + latents_shape = ( + batch_size, + self.unet.config.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if latents is None: + latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + def loop_body(step, args): + latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) + + down_block_res_samples, mid_block_res_sample = self.controlnet.apply( + {"params": params["controlnet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + return_dict=False, + ) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": params["unet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ).sample + + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents, scheduler_state + + scheduler_state = self.scheduler.set_timesteps( + params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape + ) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * params["scheduler"].init_noise_sigma + + if DEBUG: + # run with python for loop + for i in range(num_inference_steps): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + else: + latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + + # scale and decode the image latents with vae + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample + + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image + + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt_ids: jnp.ndarray, + image: jnp.ndarray, + params: dict | FrozenDict, + prng_seed: jax.Array, + num_inference_steps: int = 50, + guidance_scale: float | jnp.ndarray = 7.5, + latents: jnp.ndarray = None, + neg_prompt_ids: jnp.ndarray = None, + controlnet_conditioning_scale: float | jnp.ndarray = 1.0, + return_dict: bool = True, + jit: bool = False, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt_ids (`jnp.ndarray`): + The prompt or prompts to guide the image generation. + image (`jnp.ndarray`): + Array representing the ControlNet input condition to provide guidance to the `unet` for generation. + params (`Dict` or `FrozenDict`): + Dictionary containing the model parameters/weights. + prng_seed (`jax.Array`): + Array containing random number generator key. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + latents (`jnp.ndarray`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + array is generated by sampling using the supplied random `generator`. + controlnet_conditioning_scale (`float` or `jnp.ndarray`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of + a plain tuple. + jit (`bool`, defaults to `False`): + Whether to run `pmap` versions of the generation and safety scoring functions. + + > [!WARNING] > This argument exists because `__call__` is not yet end-to-end pmap-able. It will be + removed in a > future release. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated images + and the second element is a list of `bool`s indicating whether the corresponding generated image + contains "not-safe-for-work" (nsfw) content. + """ + + height, width = image.shape[-2:] + + if isinstance(guidance_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. + guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + guidance_scale = guidance_scale[:, None] + + if isinstance(controlnet_conditioning_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. + controlnet_conditioning_scale = jnp.array([controlnet_conditioning_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + controlnet_conditioning_scale = controlnet_conditioning_scale[:, None] + + if jit: + images = _p_generate( + self, + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, + ) + else: + images = self._generate( + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, + ) + + if self.safety_checker is not None: + safety_params = params["safety_checker"] + images_uint8_casted = (images * 255).round().astype("uint8") + num_devices, batch_size = images.shape[:2] + + images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) + images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) + images = np.array(images) + + # block images + if any(has_nsfw_concept): + for i, is_nsfw in enumerate(has_nsfw_concept): + if is_nsfw: + images[i] = np.asarray(images_uint8_casted[i]) + + images = images.reshape(num_devices, batch_size, height, width, 3) + else: + images = np.asarray(images) + has_nsfw_concept = False + + if not return_dict: + return (images, has_nsfw_concept) + + return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) + + +# Static argnums are pipe, num_inference_steps. A change would trigger recompilation. +# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). +@partial( + jax.pmap, + in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0), + static_broadcasted_argnums=(0, 5), +) +def _p_generate( + pipe, + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, +): + return pipe._generate( + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, + ) + + +@partial(jax.pmap, static_broadcasted_argnums=(0,)) +def _p_get_has_nsfw_concepts(pipe, features, params): + return pipe._get_has_nsfw_concepts(features, params) + + +def unshard(x: jnp.ndarray): + # einops.rearrange(x, 'd b ... -> (d b) ...') + num_devices, batch_size = x.shape[:2] + rest = x.shape[2:] + return x.reshape(num_devices * batch_size, *rest) + + +def preprocess(image, dtype): + image = image.convert("RGB") + w, h = image.size + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = jnp.array(image).astype(dtype) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + return image diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_hunyuandit/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_hunyuandit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..34c59795de328f2fa2cbe610625d866e4710d4a2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_hunyuandit/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_hunyuandit_controlnet"] = ["HunyuanDiTControlNetPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_hunyuandit_controlnet import HunyuanDiTControlNetPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_hunyuandit/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_hunyuandit/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3137737d7e1012ff51d06ef3abc22f3578af65f0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_hunyuandit/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_hunyuandit/__pycache__/pipeline_hunyuandit_controlnet.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_hunyuandit/__pycache__/pipeline_hunyuandit_controlnet.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0da8a701b97f8f46fcb65a77284374f1854e9c07 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_hunyuandit/__pycache__/pipeline_hunyuandit_controlnet.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..8882b561f0a141f60e38088855d76d3ed4ac918e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py @@ -0,0 +1,1050 @@ +# Copyright 2025 HunyuanDiT Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import torch +from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import AutoencoderKL, HunyuanDiT2DControlNetModel, HunyuanDiT2DModel, HunyuanDiT2DMultiControlNetModel +from ...models.embeddings import get_2d_rotary_pos_embed +from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from ...schedulers import DDPMScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + from diffusers import HunyuanDiT2DControlNetModel, HunyuanDiTControlNetPipeline + import torch + + controlnet = HunyuanDiT2DControlNetModel.from_pretrained( + "Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Canny", torch_dtype=torch.float16 + ) + + pipe = HunyuanDiTControlNetPipeline.from_pretrained( + "Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", controlnet=controlnet, torch_dtype=torch.float16 + ) + pipe.to("cuda") + + from diffusers.utils import load_image + + cond_image = load_image( + "https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Canny/resolve/main/canny.jpg?download=true" + ) + + ## You may also use English prompt as HunyuanDiT supports both English and Chinese + prompt = "在夜晚的酒店门前,一座古老的中国风格的狮子雕像矗立着,它的眼睛闪烁着光芒,仿佛在守护着这座建筑。背景是夜晚的酒店前,构图方式是特写,平视,居中构图。这张照片呈现了真实摄影风格,蕴含了中国雕塑文化,同时展现了神秘氛围" + # prompt="At night, an ancient Chinese-style lion statue stands in front of the hotel, its eyes gleaming as if guarding the building. The background is the hotel entrance at night, with a close-up, eye-level, and centered composition. This photo presents a realistic photographic style, embodies Chinese sculpture culture, and reveals a mysterious atmosphere." + image = pipe( + prompt, + height=1024, + width=1024, + control_image=cond_image, + num_inference_steps=50, + ).images[0] + ``` +""" + +STANDARD_RATIO = np.array( + [ + 1.0, # 1:1 + 4.0 / 3.0, # 4:3 + 3.0 / 4.0, # 3:4 + 16.0 / 9.0, # 16:9 + 9.0 / 16.0, # 9:16 + ] +) +STANDARD_SHAPE = [ + [(1024, 1024), (1280, 1280)], # 1:1 + [(1024, 768), (1152, 864), (1280, 960)], # 4:3 + [(768, 1024), (864, 1152), (960, 1280)], # 3:4 + [(1280, 768)], # 16:9 + [(768, 1280)], # 9:16 +] +STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE] +SUPPORTED_SHAPE = [ + (1024, 1024), + (1280, 1280), # 1:1 + (1024, 768), + (1152, 864), + (1280, 960), # 4:3 + (768, 1024), + (864, 1152), + (960, 1280), # 3:4 + (1280, 768), # 16:9 + (768, 1280), # 9:16 +] + + +def map_to_standard_shapes(target_width, target_height): + target_ratio = target_width / target_height + closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio)) + closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height)) + width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx] + return width, height + + +def get_resize_crop_region_for_grid(src, tgt_size): + th = tw = tgt_size + h, w = src + + r = h / w + + # resize + if r > 1: + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class HunyuanDiTControlNetPipeline(DiffusionPipeline): + r""" + Pipeline for English/Chinese-to-image generation using HunyuanDiT. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + HunyuanDiT uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by + ourselves) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. We use + `sdxl-vae-fp16-fix`. + text_encoder (`~transformers.BertModel`, `~transformers.CLIPTextModel` | None): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + HunyuanDiT uses a fine-tuned [bilingual CLIP]. + tokenizer (`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer` | None): + A `BertTokenizer` or `CLIPTokenizer` to tokenize text. + transformer ([`HunyuanDiT2DModel`]): + The HunyuanDiT model designed by Tencent Hunyuan. + text_encoder_2 (`T5EncoderModel`): + The mT5 embedder. Specifically, it is 't5-v1_1-xxl'. + tokenizer_2 (`T5Tokenizer`): + The tokenizer for the mT5 embedder. + scheduler ([`DDPMScheduler`]): + A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents. + controlnet ([`HunyuanDiT2DControlNetModel`] or `list[HunyuanDiT2DControlNetModel]` or [`HunyuanDiT2DControlNetModel`]): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [ + "safety_checker", + "feature_extractor", + "text_encoder_2", + "tokenizer_2", + "text_encoder", + "tokenizer", + ] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "prompt_embeds_2", + "negative_prompt_embeds_2", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: BertModel, + tokenizer: BertTokenizer, + transformer: HunyuanDiT2DModel, + scheduler: DDPMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + controlnet: HunyuanDiT2DControlNetModel + | list[HunyuanDiT2DControlNetModel] + | tuple[HunyuanDiT2DControlNetModel] + | HunyuanDiT2DMultiControlNetModel, + text_encoder_2: T5EncoderModel | None = None, + tokenizer_2: T5Tokenizer | None = None, + requires_safety_checker: bool = True, + ): + super().__init__() + if isinstance(controlnet, (list, tuple)): + controlnet = HunyuanDiT2DMultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + text_encoder_2=text_encoder_2, + controlnet=controlnet, + ) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + + # Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + device: torch.device = None, + dtype: torch.dtype = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + max_sequence_length: int | None = None, + text_encoder_index: int = 0, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + dtype (`torch.dtype`): + torch dtype + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + text_encoder_index (`int`, *optional*): + Index of the text encoder to use. `0` for clip and `1` for T5. + """ + if dtype is None: + if self.text_encoder_2 is not None: + dtype = self.text_encoder_2.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + if device is None: + device = self._execution_device + + tokenizers = [self.tokenizer, self.tokenizer_2] + text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = tokenizers[text_encoder_index] + text_encoder = text_encoders[text_encoder_index] + + if max_sequence_length is None: + if text_encoder_index == 0: + max_length = 77 + if text_encoder_index == 1: + max_length = 256 + else: + max_length = max_sequence_length + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = text_encoder( + text_input_ids.to(device), + attention_mask=prompt_attention_mask, + ) + prompt_embeds = prompt_embeds[0] + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_attention_mask = uncond_input.attention_mask.to(device) + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + prompt_embeds_2=None, + negative_prompt_embeds_2=None, + prompt_attention_mask_2=None, + negative_prompt_attention_mask_2=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if prompt_embeds_2 is not None and prompt_attention_mask_2 is None: + raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None: + raise ValueError( + "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`." + ) + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None: + if prompt_embeds_2.shape != negative_prompt_embeds_2.shape: + raise ValueError( + "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but" + f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`" + f" {negative_prompt_embeds_2.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int | None = 50, + guidance_scale: float | None = 5.0, + control_image: PipelineImageInput = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float | None = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_2: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_2: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + prompt_attention_mask_2: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask_2: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + guidance_rescale: float = 0.0, + original_size: tuple[int, int] | None = (1024, 1024), + target_size: tuple[int, int] | None = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + use_resolution_binning: bool = True, + ): + r""" + The call function to the pipeline for generation with HunyuanDiT. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback_on_step_end (`Callable[[int, int], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A callback function or a list of callback functions to be called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`list[str]`, *optional*): + A list of tensor inputs that should be passed to the callback function. If not defined, all tensor + inputs will be passed. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise + Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). See Section 3.4 + original_size (`tuple[int, int]`, *optional*, defaults to `(1024, 1024)`): + The original size of the image. Used to calculate the time ids. + target_size (`tuple[int, int]`, *optional*): + The target size of the image. Used to calculate the time ids. + crops_coords_top_left (`tuple[int, int]`, *optional*, defaults to `(0, 0)`): + The top left coordinates of the crop. Used to calculate the time ids. + use_resolution_binning (`bool`, *optional*, defaults to `True`): + Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest + standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960, + 768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. default height and width + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + height = int((height // 16) * 16) + width = int((width // 16) * 16) + + if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE: + width, height = map_to_standard_shapes(width, height) + height = int(height) + width = int(width) + logger.warning(f"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}") + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=77, + text_encoder_index=0, + ) + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds_2, + negative_prompt_embeds=negative_prompt_embeds_2, + prompt_attention_mask=prompt_attention_mask_2, + negative_prompt_attention_mask=negative_prompt_attention_mask_2, + max_sequence_length=256, + text_encoder_index=1, + ) + + # 4. Prepare control image + if isinstance(self.controlnet, HunyuanDiT2DControlNetModel): + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + height, width = control_image.shape[-2:] + + control_image = self.vae.encode(control_image).latent_dist.sample() + control_image = control_image * self.vae.config.scaling_factor + + elif isinstance(self.controlnet, HunyuanDiT2DMultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + + control_image_ = self.vae.encode(control_image_).latent_dist.sample() + control_image_ = control_image_ * self.vae.config.scaling_factor + + control_images.append(control_image_) + + control_image = control_images + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. create image_rotary_emb, style embedding & time ids + grid_height = height // 8 // self.transformer.config.patch_size + grid_width = width // 8 // self.transformer.config.patch_size + base_size = 512 // 8 // self.transformer.config.patch_size + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) + image_rotary_emb = get_2d_rotary_pos_embed( + self.transformer.inner_dim // self.transformer.num_heads, + grid_crops_coords, + (grid_height, grid_width), + device=device, + output_type="pt", + ) + + style = torch.tensor([0], device=device) + + target_size = target_size or (height, width) + add_time_ids = list(original_size + target_size + crops_coords_top_left) + add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) + style = torch.cat([style] * 2, dim=0) + + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + prompt_embeds_2 = prompt_embeds_2.to(device=device) + prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device) + add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat( + batch_size * num_images_per_prompt, 1 + ) + style = style.to(device=device).repeat(batch_size * num_images_per_prompt) + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( + dtype=latent_model_input.dtype + ) + + # controlnet(s) inference + control_block_samples = self.controlnet( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + text_embedding_mask=prompt_attention_mask, + encoder_hidden_states_t5=prompt_embeds_2, + text_embedding_mask_t5=prompt_attention_mask_2, + image_meta_size=add_time_ids, + style=style, + image_rotary_emb=image_rotary_emb, + return_dict=False, + controlnet_cond=control_image, + conditioning_scale=controlnet_conditioning_scale, + )[0] + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + text_embedding_mask=prompt_attention_mask, + encoder_hidden_states_t5=prompt_embeds_2, + text_embedding_mask_t5=prompt_attention_mask_2, + image_meta_size=add_time_ids, + style=style, + image_rotary_emb=image_rotary_emb, + return_dict=False, + controlnet_block_samples=control_block_samples, + )[0] + + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2) + negative_prompt_embeds_2 = callback_outputs.pop( + "negative_prompt_embeds_2", negative_prompt_embeds_2 + ) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_sd3/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_sd3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aeb61dc8e247555e654d2f84c035dff53328ea0d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_sd3/__init__.py @@ -0,0 +1,57 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_3_controlnet"] = ["StableDiffusion3ControlNetPipeline"] + _import_structure["pipeline_stable_diffusion_3_controlnet_inpainting"] = [ + "StableDiffusion3ControlNetInpaintingPipeline" + ] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_diffusion_3_controlnet import StableDiffusion3ControlNetPipeline + from .pipeline_stable_diffusion_3_controlnet_inpainting import StableDiffusion3ControlNetInpaintingPipeline + + try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_flax_and_transformers_objects import * # noqa F403 + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_sd3/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_sd3/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40d8322f1995336be022cce3cf8d58d2b8f57b8f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_sd3/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_sd3/__pycache__/pipeline_stable_diffusion_3_controlnet.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_sd3/__pycache__/pipeline_stable_diffusion_3_controlnet.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a973cd5bc7b3e0acfae9b466d1092d8339827bea Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_sd3/__pycache__/pipeline_stable_diffusion_3_controlnet.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_sd3/__pycache__/pipeline_stable_diffusion_3_controlnet_inpainting.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_sd3/__pycache__/pipeline_stable_diffusion_3_controlnet_inpainting.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f4fe8ca9006fe31e203672f3098fe241f4340cb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_sd3/__pycache__/pipeline_stable_diffusion_3_controlnet_inpainting.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a787a34bdc01a539e0235d1e612f6701c0c53a87 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py @@ -0,0 +1,1251 @@ +# Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + SiglipImageProcessor, + SiglipVisionModel, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel +from ...models.transformers import SD3Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusion3ControlNetPipeline + >>> from diffusers.models import SD3ControlNetModel, SD3MultiControlNetModel + >>> from diffusers.utils import load_image + + >>> controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16) + + >>> pipe = StableDiffusion3ControlNetPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + >>> control_image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ... ) + >>> prompt = "A bird in space" + >>> image = pipe( + ... prompt, control_image=control_image, height=1024, width=768, controlnet_conditioning_scale=0.7 + ... ).images[0] + >>> image.save("sd3.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusion3ControlNetPipeline( + DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin +): + r""" + Args: + transformer ([`SD3Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, + with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size` + as its dimension. + text_encoder_2 ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + text_encoder_3 ([`T5EncoderModel`]): + Frozen text-encoder. Stable Diffusion 3 uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_3 (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + controlnet ([`SD3ControlNetModel`] or `list[SD3ControlNetModel]` or [`SD3MultiControlNetModel`]): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + image_encoder (`SiglipVisionModel`, *optional*): + Pre-trained Vision Model for IP Adapter. + feature_extractor (`SiglipImageProcessor`, *optional*): + Image processor for IP Adapter. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] + + def __init__( + self, + transformer: SD3Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5TokenizerFast, + controlnet: SD3ControlNetModel + | list[SD3ControlNetModel] + | tuple[SD3ControlNetModel] + | SD3MultiControlNetModel, + image_encoder: SiglipVisionModel | None = None, + feature_extractor: SiglipImageProcessor | None = None, + ): + super().__init__() + if isinstance(controlnet, (list, tuple)): + controlnet = SD3MultiControlNetModel(controlnet) + if isinstance(controlnet, SD3MultiControlNetModel): + for controlnet_model in controlnet.nets: + # for SD3.5 8b controlnet, it shares the pos_embed with the transformer + if ( + hasattr(controlnet_model.config, "use_pos_embed") + and controlnet_model.config.use_pos_embed is False + ): + pos_embed = controlnet_model._get_pos_embed_from_transformer(transformer) + controlnet_model.pos_embed = pos_embed.to(controlnet_model.dtype).to(controlnet_model.device) + elif isinstance(controlnet, SD3ControlNetModel): + if hasattr(controlnet.config, "use_pos_embed") and controlnet.config.use_pos_embed is False: + pos_embed = controlnet._get_pos_embed_from_transformer(transformer) + controlnet.pos_embed = pos_embed.to(controlnet.dtype).to(controlnet.device) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + transformer=transformer, + scheduler=scheduler, + controlnet=controlnet, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 256, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if self.text_encoder_3 is None: + return torch.zeros( + ( + batch_size * num_images_per_prompt, + max_sequence_length, + self.transformer.config.joint_attention_dim, + ), + device=device, + dtype=dtype, + ) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] + + dtype = self.text_encoder_3.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + clip_skip: int | None = None, + clip_model_index: int = 0, + ): + device = device or self._execution_device + + clip_tokenizers = [self.tokenizer, self.tokenizer_2] + clip_text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str], + prompt_3: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_3: str | list[str] | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + clip_skip: int | None = None, + max_sequence_length: int = 256, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + prompt=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=1, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( + negative_prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( + negative_prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) + + t5_negative_prompt_embed = self._get_t5_prompt_embeds( + prompt=negative_prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_clip_prompt_embeds = torch.nn.functional.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + + negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + ) + + if self.text_encoder is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_3 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): + raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_3 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image + def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor: + """Encodes the given image into a feature representation using a pre-trained image encoder. + + Args: + image (`PipelineImageInput`): + Input image to be encoded. + device: (`torch.device`): + Torch device. + + Returns: + `torch.Tensor`: The encoded image feature representation. + """ + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=self.dtype) + + return self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: torch.Tensor | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + ) -> torch.Tensor: + """Prepares image embeddings for use in the IP-Adapter. + + Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed. + + Args: + ip_adapter_image (`PipelineImageInput`, *optional*): + The input image to extract features from for IP-Adapter. + ip_adapter_image_embeds (`torch.Tensor`, *optional*): + Precomputed image embeddings. + device: (`torch.device`, *optional*): + Torch device. + num_images_per_prompt (`int`, defaults to 1): + Number of images that should be generated per prompt. + do_classifier_free_guidance (`bool`, defaults to True): + Whether to use classifier free guidance or not. + """ + device = device or self._execution_device + + if ip_adapter_image_embeds is not None: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2) + else: + single_image_embeds = ip_adapter_image_embeds + elif ip_adapter_image is not None: + single_image_embeds = self.encode_image(ip_adapter_image, device) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.zeros_like(single_image_embeds) + else: + raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.") + + image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0) + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + + return image_embeds.to(device=device) + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, *args, **kwargs): + if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload: + logger.warning( + "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses " + "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling " + "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`." + ) + + super().enable_sequential_cpu_offload(*args, **kwargs) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + prompt_3: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 28, + sigmas: list[float] | None = None, + guidance_scale: float = 7.0, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + control_image: PipelineImageInput = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + controlnet_pooled_projections: torch.FloatTensor | None = None, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_3: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 256, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + controlnet_pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): + Embeddings projected from the embeddings of controlnet input conditions. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used instead + negative_prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used instead + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images, + emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to + `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + controlnet_config = ( + self.controlnet.config + if isinstance(self.controlnet, SD3ControlNetModel) + else self.controlnet.nets[0].config + ) + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(self.controlnet.nets) if isinstance(self.controlnet, SD3MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + dtype = self.transformer.dtype + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 3. Prepare control image + if controlnet_config.force_zeros_for_pooled_projection: + # instantx sd3 controlnet does not apply shift factor + vae_shift_factor = 0 + else: + vae_shift_factor = self.vae.config.shift_factor + if isinstance(self.controlnet, SD3ControlNetModel): + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + height, width = control_image.shape[-2:] + + control_image = self.vae.encode(control_image).latent_dist.sample() + control_image = (control_image - vae_shift_factor) * self.vae.config.scaling_factor + elif isinstance(self.controlnet, SD3MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + + control_image_ = self.vae.encode(control_image_).latent_dist.sample() + control_image_ = (control_image_ - vae_shift_factor) * self.vae.config.scaling_factor + + control_images.append(control_image_) + + control_image = control_images + else: + assert False + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(self.controlnet, SD3ControlNetModel) else keeps) + + if controlnet_config.force_zeros_for_pooled_projection: + # instantx sd3 controlnet used zero pooled projection + controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds) + else: + controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds + + if controlnet_config.joint_attention_dim is not None: + controlnet_encoder_hidden_states = prompt_embeds + else: + # SD35 official 8b controlnet does not use encoder_hidden_states + controlnet_encoder_hidden_states = None + + # 7. Prepare image embeddings + if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds} + else: + self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds) + + # 8. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + # controlnet(s) inference + control_block_samples = self.controlnet( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=controlnet_encoder_hidden_states, + pooled_projections=controlnet_pooled_projections, + joint_attention_kwargs=self.joint_attention_kwargs, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + return_dict=False, + )[0] + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + block_controlnet_hidden_states=control_block_samples, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py new file mode 100644 index 0000000000000000000000000000000000000000..96f53b16cbe8fb64078aa2fdb6f187d8e526ef09 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py @@ -0,0 +1,1411 @@ +# Copyright 2025 Stability AI, The HuggingFace Team and The AlimamaCreative Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + SiglipImageProcessor, + SiglipModel, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel +from ...models.transformers import SD3Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers.utils import load_image, check_min_version + >>> from diffusers.pipelines import StableDiffusion3ControlNetInpaintingPipeline + >>> from diffusers.models.controlnet_sd3 import SD3ControlNetModel + + >>> controlnet = SD3ControlNetModel.from_pretrained( + ... "alimama-creative/SD3-Controlnet-Inpainting", use_safetensors=True, extra_conditioning_channels=1 + ... ) + >>> pipe = StableDiffusion3ControlNetInpaintingPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-3-medium-diffusers", + ... controlnet=controlnet, + ... torch_dtype=torch.float16, + ... ) + >>> pipe.text_encoder.to(torch.float16) + >>> pipe.controlnet.to(torch.float16) + >>> pipe.to("cuda") + + >>> image = load_image( + ... "https://huggingface.co/alimama-creative/SD3-Controlnet-Inpainting/resolve/main/images/dog.png" + ... ) + >>> mask = load_image( + ... "https://huggingface.co/alimama-creative/SD3-Controlnet-Inpainting/resolve/main/images/dog_mask.png" + ... ) + >>> width = 1024 + >>> height = 1024 + >>> prompt = "A cat is sitting next to a puppy." + >>> generator = torch.Generator(device="cuda").manual_seed(24) + >>> res_image = pipe( + ... negative_prompt="deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, NSFW", + ... prompt=prompt, + ... height=height, + ... width=width, + ... control_image=image, + ... control_mask=mask, + ... num_inference_steps=28, + ... generator=generator, + ... controlnet_conditioning_scale=0.95, + ... guidance_scale=7, + ... ).images[0] + >>> res_image.save(f"sd3.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusion3ControlNetInpaintingPipeline( + DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin +): + r""" + Args: + transformer ([`SD3Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, + with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size` + as its dimension. + text_encoder_2 ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + text_encoder_3 ([`T5EncoderModel`]): + Frozen text-encoder. Stable Diffusion 3 uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_3 (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + controlnet ([`SD3ControlNetModel`] or `list[SD3ControlNetModel]` or [`SD3MultiControlNetModel`]): + Provides additional conditioning to the `transformer` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + image_encoder (`PreTrainedModel`, *optional*): + Pre-trained Vision Model for IP Adapter. + feature_extractor (`BaseImageProcessor`, *optional*): + Image processor for IP Adapter. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] + + def __init__( + self, + transformer: SD3Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5TokenizerFast, + controlnet: SD3ControlNetModel + | list[SD3ControlNetModel] + | tuple[SD3ControlNetModel] + | SD3MultiControlNetModel, + image_encoder: SiglipModel = None, + feature_extractor: SiglipImageProcessor | None = None, + ): + super().__init__() + if isinstance(controlnet, (list, tuple)): + controlnet = SD3MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + transformer=transformer, + scheduler=scheduler, + controlnet=controlnet, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_resize=True, do_convert_rgb=True, do_normalize=True + ) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, + do_resize=True, + do_convert_grayscale=True, + do_normalize=False, + do_binarize=True, + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + self.patch_size = ( + self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2 + ) + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 256, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if self.text_encoder_3 is None: + return torch.zeros( + ( + batch_size * num_images_per_prompt, + max_sequence_length, + self.transformer.config.joint_attention_dim, + ), + device=device, + dtype=dtype, + ) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] + + dtype = self.text_encoder_3.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + clip_skip: int | None = None, + clip_model_index: int = 0, + ): + device = device or self._execution_device + + clip_tokenizers = [self.tokenizer, self.tokenizer_2] + clip_text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str], + prompt_3: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_3: str | list[str] | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + clip_skip: int | None = None, + max_sequence_length: int = 256, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + prompt=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=1, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( + negative_prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( + negative_prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) + + t5_negative_prompt_embed = self._get_t5_prompt_embeds( + prompt=negative_prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_clip_prompt_embeds = torch.nn.functional.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + + negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + ) + + if self.text_encoder is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def check_inputs( + self, + height, + width, + image, + prompt, + prompt_2, + prompt_3, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if ( + height % (self.vae_scale_factor * self.patch_size) != 0 + or width % (self.vae_scale_factor * self.patch_size) != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}." + f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_3 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): + raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_3 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, SD3MultiControlNetModel): + if isinstance(prompt, list) and len(prompt) > 1: + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + if isinstance(controlnet, SD3ControlNetModel): + self.check_image(image, prompt, prompt_embeds) + elif isinstance(controlnet, SD3MultiControlNetModel): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + + # Check `controlnet_conditioning_scale` + if isinstance(controlnet, SD3MultiControlNetModel): + if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(controlnet, SD3MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control_guidance_start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control_guidance_start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control_guidance_end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents + + def prepare_image_with_mask( + self, + image, + mask, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + # Prepare image + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + # Prepare mask + if isinstance(mask, torch.Tensor): + pass + else: + mask = self.mask_processor.preprocess(mask, height=height, width=width) + mask = mask.repeat_interleave(repeat_by, dim=0) + mask = mask.to(device=device, dtype=dtype) + + # Get masked image + masked_image = image.clone() + masked_image[(mask > 0.5).repeat(1, 3, 1, 1)] = -1 + + # Encode to latents + image_latents = self.vae.encode(masked_image).latent_dist.sample() + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + image_latents = image_latents.to(dtype) + + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = 1 - mask + control_image = torch.cat([image_latents, mask], dim=1) + + if do_classifier_free_guidance and not guess_mode: + control_image = torch.cat([control_image] * 2) + + return control_image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image + def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor: + """Encodes the given image into a feature representation using a pre-trained image encoder. + + Args: + image (`PipelineImageInput`): + Input image to be encoded. + device: (`torch.device`): + Torch device. + + Returns: + `torch.Tensor`: The encoded image feature representation. + """ + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=self.dtype) + + return self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: torch.Tensor | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + ) -> torch.Tensor: + """Prepares image embeddings for use in the IP-Adapter. + + Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed. + + Args: + ip_adapter_image (`PipelineImageInput`, *optional*): + The input image to extract features from for IP-Adapter. + ip_adapter_image_embeds (`torch.Tensor`, *optional*): + Precomputed image embeddings. + device: (`torch.device`, *optional*): + Torch device. + num_images_per_prompt (`int`, defaults to 1): + Number of images that should be generated per prompt. + do_classifier_free_guidance (`bool`, defaults to True): + Whether to use classifier free guidance or not. + """ + device = device or self._execution_device + + if ip_adapter_image_embeds is not None: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2) + else: + single_image_embeds = ip_adapter_image_embeds + elif ip_adapter_image is not None: + single_image_embeds = self.encode_image(ip_adapter_image, device) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.zeros_like(single_image_embeds) + else: + raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.") + + image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0) + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + + return image_embeds.to(device=device) + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, *args, **kwargs): + if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload: + logger.warning( + "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses " + "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling " + "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`." + ) + + super().enable_sequential_cpu_offload(*args, **kwargs) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + prompt_3: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 28, + sigmas: list[float] | None = None, + guidance_scale: float = 7.0, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + control_image: PipelineImageInput = None, + control_mask: PipelineImageInput = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + controlnet_pooled_projections: torch.FloatTensor | None = None, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_3: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 256, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be inpainted (which parts of the image to + be masked out with `control_mask` and repainted according to `prompt`). For both numpy array and + pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the + expected shape should be `(B, C, H, W)`. If it is a numpy array or a list of arrays, the expected shape + should be `(B, H, W, C)` or `(H, W, C)`. + control_mask (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`. And + for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, 1)`, or `(H, W)`. + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + controlnet_pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): + Embeddings projected from the embeddings of controlnet input conditions. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used instead + negative_prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used instead + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images, + emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to + `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(self.controlnet.nets) if isinstance(self.controlnet, SD3MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + height, + width, + control_image, + prompt, + prompt_2, + prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + ip_adapter_image=ip_adapter_image, + ip_adapter_image_embeds=ip_adapter_image_embeds, + controlnet_conditioning_scale=controlnet_conditioning_scale, + control_guidance_start=control_guidance_start, + control_guidance_end=control_guidance_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + dtype = self.transformer.dtype + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 3. Prepare control image + if isinstance(self.controlnet, SD3ControlNetModel): + control_image = self.prepare_image_with_mask( + image=control_image, + mask=control_mask, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + latent_height, latent_width = control_image.shape[-2:] + + height = latent_height * self.vae_scale_factor + width = latent_width * self.vae_scale_factor + + elif isinstance(self.controlnet, SD3MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_image_with_mask( + image=control_image_, + mask=control_mask, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + control_images.append(control_image_) + + control_image = control_images + else: + assert ValueError("Controlnet not found. Please check the controlnet model.") + + if controlnet_pooled_projections is None: + controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds) + else: + controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(self.controlnet, SD3ControlNetModel) else keeps) + + # 7. Prepare image embeddings + if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds} + else: + self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds) + + # 8. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + # controlnet(s) inference + control_block_samples = self.controlnet( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=controlnet_pooled_projections, + joint_attention_kwargs=self.joint_attention_kwargs, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + return_dict=False, + )[0] + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + block_controlnet_hidden_states=control_block_samples, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + latents = latents.to(dtype=self.vae.dtype) + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_xs/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_xs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..978278b184f985a452f9d518a1d0eb4f271c74fd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_xs/__init__.py @@ -0,0 +1,68 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_controlnet_xs"] = ["StableDiffusionControlNetXSPipeline"] + _import_structure["pipeline_controlnet_xs_sd_xl"] = ["StableDiffusionXLControlNetXSPipeline"] +try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_flax_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) +else: + pass # _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline + from .pipeline_controlnet_xs_sd_xl import StableDiffusionXLControlNetXSPipeline + + try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_flax_and_transformers_objects import * # noqa F403 + else: + pass # from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline + + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_xs/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_xs/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85293aebe32d3b29d34cda0033ed5efd831d8317 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_xs/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_xs/__pycache__/pipeline_controlnet_xs.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_xs/__pycache__/pipeline_controlnet_xs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31d90fa3373afcaee40ad52103d12284ca8efc6e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_xs/__pycache__/pipeline_controlnet_xs.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_xs/__pycache__/pipeline_controlnet_xs_sd_xl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_xs/__pycache__/pipeline_controlnet_xs_sd_xl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff36361f95851b8d5965cc877aae70a67db70bfe Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_xs/__pycache__/pipeline_controlnet_xs_sd_xl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py new file mode 100644 index 0000000000000000000000000000000000000000..9c81eb57e6c53356a7605a0499131b6b07026ca6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -0,0 +1,927 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionControlNetXSPipeline, ControlNetXSAdapter + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" + >>> negative_prompt = "low quality, bad quality, sketches" + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... ) + + >>> # initialize the models and pipeline + >>> controlnet_conditioning_scale = 0.5 + + >>> controlnet = ControlNetXSAdapter.from_pretrained( + ... "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16 + ... ) + >>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> # get canny image + >>> image = np.array(image) + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + >>> # generate image + >>> image = pipe( + ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image + ... ).images[0] + ``` +""" + + +class StableDiffusionControlNetXSPipeline( + DeprecatedPipelineMixin, + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet-XS guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetControlNetXSModel to denoise the encoded image latents. + controlnet ([`ControlNetXSAdapter`]): + A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + _last_supported_version = "0.33.1" + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel | UNetControlNetXSModel, + controlnet: ControlNetXSAdapter, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if isinstance(unet, UNet2DConditionModel): + unet = UNetControlNetXSModel.from_unet(unet, controlnet) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Check `image` and `controlnet_conditioning_scale` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.unet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.unet, UNetControlNetXSModel) + or is_compiled + and isinstance(self.unet._orig_mod, UNetControlNetXSModel) + ): + self.check_image(image, prompt, prompt_embeds) + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + else: + assert False + + start, end = control_guidance_start, control_guidance_end + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale + def guidance_scale(self): + return self._guidance_scale + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip + def clip_skip(self): + return self._clip_skip + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + control_guidance_start: float = 0.0, + control_guidance_end: float = 1.0, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare image + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=unet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + height, width = image.shape[-2:] + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + is_controlnet_compiled = is_compiled_module(self.unet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if torch.cuda.is_available() and is_controlnet_compiled and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + apply_control = ( + i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end + ) + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=True, + apply_control=apply_control, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + empty_device_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..e7a5862b5b7a47849f7d77364f7544d9759cfbf6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -0,0 +1,1116 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, +) + +from diffusers.utils.import_utils import is_invisible_watermark_available + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + + +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionXLControlNetXSPipeline, ControlNetXSAdapter, AutoencoderKL + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" + >>> negative_prompt = "low quality, bad quality, sketches" + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... ) + + >>> # initialize the models and pipeline + >>> controlnet_conditioning_scale = 0.5 + >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + >>> controlnet = ControlNetXSAdapter.from_pretrained( + ... "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16 + ... ) + >>> pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> # get canny image + >>> image = np.array(image) + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # generate image + >>> image = pipe( + ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image + ... ).images[0] + ``` +""" + + +class StableDiffusionXLControlNetXSPipeline( + DeprecatedPipelineMixin, + DiffusionPipeline, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet-XS guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]): + Second frozen text-encoder + ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + tokenizer_2 ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetControlNetXSModel to denoise the encoded image latents. + controlnet ([`ControlNetXSAdapter`]): + A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings should always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to + watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no + watermarker is used. + """ + + _last_supported_version = "0.33.1" + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel | UNetControlNetXSModel, + controlnet: ControlNetXSAdapter, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool | None = None, + feature_extractor: CLIPImageProcessor = None, + ): + super().__init__() + + if isinstance(unet, UNet2DConditionModel): + unet = UNetControlNetXSModel.from_unet(unet, controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + image, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Check `image` and ``controlnet_conditioning_scale`` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.unet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.unet, UNetControlNetXSModel) + or is_compiled + and isinstance(self.unet._orig_mod, UNetControlNetXSModel) + ): + self.check_image(image, prompt, prompt_embeds) + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + else: + assert False + + start, end = control_guidance_start, control_guidance_end + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.base_add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale + def guidance_scale(self): + return self._guidance_scale + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip + def clip_skip(self): + return self._clip_skip + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + control_guidance_start: float = 0.0, + control_guidance_end: float = 1.0, + original_size: tuple[int, int] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` + and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, pooled text embeddings are generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt + weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. + control_guidance_start (`float`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] is + returned, otherwise a `tuple` is returned containing the output images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + + # 4. Prepare image + if isinstance(unet, UNetControlNetXSModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=unet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + height, width = image.shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Prepare added time ids & embeddings + if isinstance(image, list): + original_size = original_size or image[0].shape[-2:] + else: + original_size = original_size or image.shape[-2:] + target_size = target_size or (height, width) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + is_controlnet_compiled = is_compiled_module(self.unet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if is_controlnet_compiled and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # predict the noise residual + apply_control = ( + i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end + ) + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=True, + apply_control=apply_control, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # manually for max memory savings + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5fc66cdf84b662e08d06949a9d14f00eb302627b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__init__.py @@ -0,0 +1,64 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_cosmos2_5_predict"] = [ + "Cosmos2_5_PredictBasePipeline", + ] + _import_structure["pipeline_cosmos2_5_transfer"] = [ + "Cosmos2_5_TransferPipeline", + ] + _import_structure["pipeline_cosmos2_text2image"] = ["Cosmos2TextToImagePipeline"] + _import_structure["pipeline_cosmos2_video2world"] = ["Cosmos2VideoToWorldPipeline"] + _import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"] + _import_structure["pipeline_cosmos_video2world"] = ["CosmosVideoToWorldPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_cosmos2_5_predict import ( + Cosmos2_5_PredictBasePipeline, + ) + from .pipeline_cosmos2_5_transfer import Cosmos2_5_TransferPipeline + from .pipeline_cosmos2_text2image import Cosmos2TextToImagePipeline + from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline + from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline + from .pipeline_cosmos_video2world import CosmosVideoToWorldPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c7dee965712a9d7a7e3e6afcf0a67c4ddfb0233 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__pycache__/pipeline_cosmos2_5_predict.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__pycache__/pipeline_cosmos2_5_predict.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca2e2cf15421c62742675522bff96fa1984ef7ac Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__pycache__/pipeline_cosmos2_5_predict.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__pycache__/pipeline_cosmos2_5_transfer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__pycache__/pipeline_cosmos2_5_transfer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4454f70f35131b83b99fba28d44d6224a0dafd3d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__pycache__/pipeline_cosmos2_5_transfer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__pycache__/pipeline_cosmos2_text2image.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__pycache__/pipeline_cosmos2_text2image.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c374bd0169217f56b515a16755f743e6f4278e1e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__pycache__/pipeline_cosmos2_text2image.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__pycache__/pipeline_cosmos2_video2world.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__pycache__/pipeline_cosmos2_video2world.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9268b9fe1265d165cca0d5325d06719c851c3846 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__pycache__/pipeline_cosmos2_video2world.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__pycache__/pipeline_cosmos_text2world.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__pycache__/pipeline_cosmos_text2world.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5840d8560f9d42718a753db9e1d5fd762ff61f0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__pycache__/pipeline_cosmos_text2world.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__pycache__/pipeline_cosmos_video2world.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__pycache__/pipeline_cosmos_video2world.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19de8a911435b424bbdb29e3f800140888ef8ee7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__pycache__/pipeline_cosmos_video2world.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__pycache__/pipeline_output.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__pycache__/pipeline_output.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23c846323518dafb63eb7edcce8fb6a2f4ee2001 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/__pycache__/pipeline_output.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py new file mode 100644 index 0000000000000000000000000000000000000000..cdea71a5ab93af1c7ec642d9b1dd862ec79a6255 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py @@ -0,0 +1,881 @@ +# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable + +import numpy as np +import torch +import torchvision +import torchvision.transforms +import torchvision.transforms.functional +from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLWan, CosmosTransformer3DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import CosmosPipelineOutput + + +if is_cosmos_guardrail_available(): + from cosmos_guardrail import CosmosSafetyChecker +else: + + class CosmosSafetyChecker: + def __init__(self, *args, **kwargs): + raise ImportError( + "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`." + ) + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +DEFAULT_NEGATIVE_PROMPT = ( + "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, " + "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " + "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " + "Overall, the video is of poor quality." +) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import Cosmos2_5_PredictBasePipeline + >>> from diffusers.utils import export_to_video, load_image, load_video + + >>> model_id = "nvidia/Cosmos-Predict2.5-2B" + >>> pipe = Cosmos2_5_PredictBasePipeline.from_pretrained( + ... model_id, revision="diffusers/base/post-trained", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> # Common negative prompt reused across modes. + >>> negative_prompt = ( + ... "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + ... "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + ... "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky " + ... "movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " + ... "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " + ... "Overall, the video is of poor quality." + ... ) + + >>> # Text2World: generate a 93-frame world video from text only. + >>> prompt = ( + ... "As the red light shifts to green, the red bus at the intersection begins to move forward, its headlights " + ... "cutting through the falling snow. The snowy tire tracks deepen as the vehicle inches ahead, casting fresh " + ... "lines onto the slushy road. Around it, streetlights glow warmer, illuminating the drifting flakes and wet " + ... "reflections on the asphalt. Other cars behind start to edge forward, their beams joining the scene. " + ... "The stillness of the urban street transitions into motion as the quiet snowfall is punctuated by the slow " + ... "advance of traffic through the frosty city corridor." + ... ) + >>> video = pipe( + ... image=None, + ... video=None, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_frames=93, + ... generator=torch.Generator().manual_seed(1), + ... ).frames[0] + >>> export_to_video(video, "text2world.mp4", fps=16) + + >>> # Image2World: condition on a single image and generate a 93-frame world video. + >>> prompt = ( + ... "A high-definition video captures the precision of robotic welding in an industrial setting. " + ... "The first frame showcases a robotic arm, equipped with a welding torch, positioned over a large metal structure. " + ... "The welding process is in full swing, with bright sparks and intense light illuminating the scene, creating a vivid " + ... "display of blue and white hues. A significant amount of smoke billows around the welding area, partially obscuring " + ... "the view but emphasizing the heat and activity. The background reveals parts of the workshop environment, including a " + ... "ventilation system and various pieces of machinery, indicating a busy and functional industrial workspace. As the video " + ... "progresses, the robotic arm maintains its steady position, continuing the welding process and moving to its left. " + ... "The welding torch consistently emits sparks and light, and the smoke continues to rise, diffusing slightly as it moves upward. " + ... "The metal surface beneath the torch shows ongoing signs of heating and melting. The scene retains its industrial ambiance, with " + ... "the welding sparks and smoke dominating the visual field, underscoring the ongoing nature of the welding operation." + ... ) + >>> image = load_image( + ... "https://media.githubusercontent.com/media/nvidia-cosmos/cosmos-predict2.5/refs/heads/main/assets/base/robot_welding.jpg" + ... ) + >>> video = pipe( + ... image=image, + ... video=None, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_frames=93, + ... generator=torch.Generator().manual_seed(1), + ... ).frames[0] + >>> export_to_video(video, "image2world.mp4", fps=16) + + >>> # Video2World: condition on an input clip and predict a 93-frame world video. + >>> prompt = ( + ... "The video opens with an aerial view of a large-scale sand mining construction operation, showcasing extensive piles " + ... "of brown sand meticulously arranged in parallel rows. A central water channel, fed by a water pipe, flows through the " + ... "middle of these sand heaps, creating ripples and movement as it cascades down. The surrounding area features dense green " + ... "vegetation on the left, contrasting with the sandy terrain, while a body of water is visible in the background on the right. " + ... "As the video progresses, a piece of heavy machinery, likely a bulldozer, enters the frame from the right, moving slowly along " + ... "the edge of the sand piles. This machinery's presence indicates ongoing construction work in the operation. The final frame " + ... "captures the same scene, with the water continuing its flow and the bulldozer still in motion, maintaining the dynamic yet " + ... "steady pace of the construction activity." + ... ) + >>> input_video = load_video( + ... "https://github.com/nvidia-cosmos/cosmos-predict2.5/raw/refs/heads/main/assets/base/sand_mining.mp4" + ... ) + >>> video = pipe( + ... image=None, + ... video=input_video, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_frames=93, + ... generator=torch.Generator().manual_seed(1), + ... ).frames[0] + >>> export_to_video(video, "video2world.mp4", fps=16) + + >>> # To produce an image instead of a world (video) clip, set num_frames=1 and + >>> # save the first frame: pipe(..., num_frames=1).frames[0][0]. + ``` +""" + + +class Cosmos2_5_PredictBasePipeline(DiffusionPipeline): + r""" + Pipeline for [Cosmos Predict2.5](https://github.com/nvidia-cosmos/cosmos-predict2.5) base model. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`Qwen2_5_VLForConditionalGeneration`]): + Frozen text-encoder. Cosmos Predict2.5 uses the [Qwen2.5 + VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) encoder. + tokenizer (`AutoTokenizer`): + Tokenizer associated with the Qwen2.5 VL encoder. + transformer ([`CosmosTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + # We mark safety_checker as optional here to get around some test failures, but it is not really optional + _optional_components = ["safety_checker"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: AutoTokenizer, + transformer: CosmosTransformer3DModel, + vae: AutoencoderKLWan, + scheduler: UniPCMultistepScheduler, + safety_checker: CosmosSafetyChecker = None, + ): + super().__init__() + + if safety_checker is None: + safety_checker = CosmosSafetyChecker() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).float() + if getattr(self.vae.config, "latents_mean", None) is not None + else None + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).float() + if getattr(self.vae.config, "latents_std", None) is not None + else None + ) + self.latents_mean = latents_mean + self.latents_std = latents_std + + if self.latents_mean is None or self.latents_std is None: + raise ValueError("VAE configuration must define both `latents_mean` and `latents_std`.") + + def _get_prompt_embeds( + self, + prompt: str | list[str] = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + + input_ids_batch = [] + + for sample_idx in range(len(prompt)): + conversations = [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "You are a helpful assistant who will provide prompts to an image generator.", + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt[sample_idx], + } + ], + }, + ] + input_ids = self.tokenizer.apply_chat_template( + conversations, + tokenize=True, + add_generation_prompt=False, + add_vision_id=False, + max_length=max_sequence_length, + truncation=True, + padding="max_length", + ) + input_ids = ( + input_ids["input_ids"] if not isinstance(input_ids, list) and "input_ids" in input_ids else input_ids + ) + input_ids = torch.LongTensor(input_ids) + input_ids_batch.append(input_ids) + + input_ids_batch = torch.stack(input_ids_batch, dim=0) + + outputs = self.text_encoder( + input_ids_batch.to(device), + output_hidden_states=True, + ) + hidden_states = outputs.hidden_states + + normalized_hidden_states = [] + for layer_idx in range(1, len(hidden_states)): + normalized_state = (hidden_states[layer_idx] - hidden_states[layer_idx].mean(dim=-1, keepdim=True)) / ( + hidden_states[layer_idx].std(dim=-1, keepdim=True) + 1e-8 + ) + normalized_hidden_states.append(normalized_state) + + prompt_embeds = torch.cat(normalized_hidden_states, dim=-1) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds + + # Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_prompt_embeds( + prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_prompt_embeds( + prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + # Modified from diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2VideoToWorldPipeline.prepare_latents and + # diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2TextToImagePipeline.prepare_latents + def prepare_latents( + self, + video: torch.Tensor | None, + batch_size: int, + num_channels_latents: int = 16, + height: int = 704, + width: int = 1280, + num_frames_in: int = 93, + num_frames_out: int = 93, + do_classifier_free_guidance: bool = True, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + B = batch_size + C = num_channels_latents + T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1 + H = height // self.vae_scale_factor_spatial + W = width // self.vae_scale_factor_spatial + shape = (B, C, T, H, W) + + if num_frames_in == 0: + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device) + cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device) + + cond_latents = torch.zeros_like(latents) + + return ( + latents, + cond_latents, + cond_mask, + cond_indicator, + ) + else: + if video is None: + raise ValueError("`video` must be provided when `num_frames_in` is greater than 0.") + needs_preprocessing = not (isinstance(video, torch.Tensor) and video.ndim == 5 and video.shape[1] == 3) + if needs_preprocessing: + video = self.video_processor.preprocess_video(video, height, width) + video = video.to(device=device, dtype=self.vae.dtype) + if isinstance(generator, list): + cond_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i]) + for i in range(batch_size) + ] + else: + cond_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + + cond_latents = torch.cat(cond_latents, dim=0).to(dtype) + + latents_mean = self.latents_mean.to(device=device, dtype=dtype) + latents_std = self.latents_std.to(device=device, dtype=dtype) + cond_latents = (cond_latents - latents_mean) / latents_std + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + padding_shape = (B, 1, T, H, W) + ones_padding = latents.new_ones(padding_shape) + zeros_padding = latents.new_zeros(padding_shape) + + num_cond_latent_frames = (num_frames_in - 1) // self.vae_scale_factor_temporal + 1 + cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1) + cond_indicator[:, :, 0:num_cond_latent_frames] = 1.0 + cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding + + return ( + latents, + cond_latents, + cond_mask, + cond_indicator, + ) + + # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput | None = None, + video: list[PipelineImageInput] | None = None, + prompt: str | list[str] | None = None, + negative_prompt: str | list[str] | None = None, + height: int = 704, + width: int = 1280, + num_frames: int = 93, + num_inference_steps: int = 36, + guidance_scale: float = 7.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int, None], PipelineCallback | MultiPipelineCallbacks] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + conditional_frame_timestep: float = 0.1, + num_latent_conditional_frames: int = 2, + ): + r""" + The call function to the pipeline for generation. Supports three modes: + + - **Text2World**: `image=None`, `video=None`, `prompt` provided. Generates a world clip. + - **Image2World**: `image` provided, `video=None`, `prompt` provided. Conditions on a single frame. + - **Video2World**: `video` provided, `image=None`, `prompt` provided. Conditions on an input clip. + + Set `num_frames=93` (default) to produce a world video, or `num_frames=1` to produce a single image frame (the + above in "*2Image mode"). + + Outputs follow `output_type` (e.g., `"pil"` returns a list of `num_frames` PIL images per prompt). + + Args: + image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*): + Optional single image for Image2World conditioning. Must be `None` when `video` is provided. + video (`list[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*): + Optional input video for Video2World conditioning. Must be `None` when `image` is provided. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied. + height (`int`, defaults to `704`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `93`): + Number of output frames. Use `93` for world (video) generation; set to `1` to return a single frame. + num_inference_steps (`int`, defaults to `35`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `7.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If + the prompt is shorter than this length, it will be padded. + num_latent_conditional_frames (`int`, defaults to `2`): + Number of latent conditional frames to use for Video2World conditioning. The number of pixel frames + extracted from the input video is calculated as `4 * (num_latent_conditional_frames - 1) + 1`. Set to 1 + for Image2World-like behavior (single frame conditioning). + + Examples: + + Returns: + [`~CosmosPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + if self.safety_checker is None: + raise ValueError( + f"You have disabled the safety checker for {self.__class__}. This is in violation of the " + "[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). " + f"Please ensure that you are compliant with the license agreement." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs) + + self._guidance_scale = guidance_scale + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + if self.safety_checker is not None: + self.safety_checker.to(device) + if prompt is not None: + prompt_list = [prompt] if isinstance(prompt, str) else prompt + for p in prompt_list: + if not self.safety_checker.check_text_safety(p): + raise ValueError( + f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the " + f"prompt abides by the NVIDIA Open Model License Agreement." + ) + + # Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + vae_dtype = self.vae.dtype + transformer_dtype = self.transformer.dtype + + num_frames_in = None + if image is not None: + if batch_size != 1: + raise ValueError(f"batch_size must be 1 for image input (given {batch_size})") + + image = torchvision.transforms.functional.to_tensor(image).unsqueeze(0) + video = torch.cat([image, torch.zeros_like(image).repeat(num_frames - 1, 1, 1, 1)], dim=0) + video = video.unsqueeze(0) + num_frames_in = 1 + elif video is None: + video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8) + num_frames_in = 0 + else: + if batch_size != 1: + raise ValueError(f"batch_size must be 1 for video input (given {batch_size})") + + if num_latent_conditional_frames not in [1, 2]: + raise ValueError( + f"num_latent_conditional_frames must be 1 or 2, but got {num_latent_conditional_frames}" + ) + + frames_to_extract = 4 * (num_latent_conditional_frames - 1) + 1 + + total_input_frames = len(video) + + if total_input_frames < frames_to_extract: + raise ValueError( + f"Input video has only {total_input_frames} frames but Video2World requires at least " + f"{frames_to_extract} frames for conditioning." + ) + + num_frames_in = frames_to_extract + + assert video is not None + video = self.video_processor.preprocess_video(video, height, width) + + # For Video2World: extract last frames_to_extract frames from input, then pad + if image is None and num_frames_in > 0 and num_frames_in < video.shape[2]: + video = video[:, :, -num_frames_in:, :, :] + + num_frames_out = num_frames + + if video.shape[2] < num_frames_out: + n_pad_frames = num_frames_out - video.shape[2] + last_frame = video[:, :, -1:, :, :] # [B, C, T==1, H, W] + pad_frames = last_frame.repeat(1, 1, n_pad_frames, 1, 1) # [B, C, T, H, W] + video = torch.cat((video, pad_frames), dim=2) + + assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})" + + video = video.to(device=device, dtype=vae_dtype) + + num_channels_latents = self.transformer.config.in_channels - 1 + latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents( + video=video, + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames_in=num_frames_in, + num_frames_out=num_frames, + do_classifier_free_guidance=self.do_classifier_free_guidance, + dtype=torch.float32, + device=device, + generator=generator, + latents=latents, + ) + cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep + cond_mask = cond_mask.to(transformer_dtype) + + padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype) + + # Denoising loop + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + gt_velocity = (latents - cond_latent) * cond_mask + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t.cpu().item() + + # NOTE: assumes sigma(t) \in [0, 1] + sigma_t = ( + torch.tensor(self.scheduler.sigmas[i].item()) + .unsqueeze(0) + .to(device=device, dtype=transformer_dtype) + ) + + in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents + in_latents = in_latents.to(transformer_dtype) + in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t + noise_pred = self.transformer( + hidden_states=in_latents, + condition_mask=cond_mask, + timestep=in_timestep, + encoder_hidden_states=prompt_embeds, + padding_mask=padding_mask, + return_dict=False, + )[0] + # NOTE: replace velocity (noise_pred) with gt_velocity for conditioning inputs only + noise_pred = gt_velocity + noise_pred * (1 - cond_mask) + + if self.do_classifier_free_guidance: + noise_pred_neg = self.transformer( + hidden_states=in_latents, + condition_mask=cond_mask, + timestep=in_timestep, + encoder_hidden_states=negative_prompt_embeds, + padding_mask=padding_mask, + return_dict=False, + )[0] + # NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only + noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask) + noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg) + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents_mean = self.latents_mean.to(latents.device, latents.dtype) + latents_std = self.latents_std.to(latents.device, latents.dtype) + latents = latents * latents_std + latents_mean + video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0] + video = self._match_num_frames(video, num_frames) + + assert self.safety_checker is not None + self.safety_checker.to(device) + video = self.video_processor.postprocess_video(video, output_type="np") + video = (video * 255).astype(np.uint8) + video_batch = [] + for vid in video: + vid = self.safety_checker.check_video_safety(vid) + video_batch.append(vid) + video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1 + video = torch.from_numpy(video).permute(0, 4, 1, 2, 3) + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CosmosPipelineOutput(frames=video) + + def _match_num_frames(self, video: torch.Tensor, target_num_frames: int) -> torch.Tensor: + if target_num_frames <= 0 or video.shape[2] == target_num_frames: + return video + + frames_per_latent = max(self.vae_scale_factor_temporal, 1) + video = torch.repeat_interleave(video, repeats=frames_per_latent, dim=2) + + current_frames = video.shape[2] + if current_frames < target_num_frames: + pad = video[:, :, -1:, :, :].repeat(1, 1, target_num_frames - current_frames, 1, 1) + video = torch.cat([video, pad], dim=2) + elif current_frames > target_num_frames: + video = video[:, :, :target_num_frames] + + return video diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py new file mode 100644 index 0000000000000000000000000000000000000000..bbe38c44355e71f97c6d93d10eaa09c84eb3677c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py @@ -0,0 +1,1012 @@ +# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLWan, CosmosControlNetModel, CosmosTransformer3DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import CosmosPipelineOutput + + +if is_cosmos_guardrail_available(): + from cosmos_guardrail import CosmosSafetyChecker +else: + + class CosmosSafetyChecker: + def __init__(self, *args, **kwargs): + raise ImportError( + "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`." + ) + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _maybe_pad_or_trim_video(video: torch.Tensor, num_frames: int): + n_pad_frames = num_frames - video.shape[2] + if n_pad_frames > 0: + last_frame = video[:, :, -1:, :, :] + video = torch.cat((video, last_frame.repeat(1, 1, n_pad_frames, 1, 1)), dim=2) + elif num_frames < video.shape[2]: + video = video[:, :, :num_frames, :, :] + return video + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +DEFAULT_NEGATIVE_PROMPT = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality." + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import cv2 + >>> import numpy as np + >>> import torch + >>> from diffusers import Cosmos2_5_TransferPipeline, AutoModel + >>> from diffusers.utils import export_to_video, load_video + + >>> model_id = "nvidia/Cosmos-Transfer2.5-2B" + >>> # Load a Transfer2.5 controlnet variant (edge, depth, seg, or blur) + >>> controlnet = AutoModel.from_pretrained(model_id, revision="diffusers/controlnet/general/edge") + >>> pipe = Cosmos2_5_TransferPipeline.from_pretrained( + ... model_id, controlnet=controlnet, revision="diffusers/general", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> # Video2World with edge control: Generate video guided by edge maps extracted from input video. + >>> prompt = ( + ... "The video is a demonstration of robotic manipulation, likely in a laboratory or testing environment. It" + ... "features two robotic arms interacting with a piece of blue fabric. The setting is a room with a beige" + ... "couch in the background, providing a neutral backdrop for the robotic activity. The robotic arms are" + ... "positioned on either side of the fabric, which is placed on a yellow cushion. The left robotic arm is" + ... "white with a black gripper, while the right arm is black with a more complex, articulated gripper. At the" + ... "beginning, the fabric is laid out on the cushion. The left robotic arm approaches the fabric, its gripper" + ... "opening and closing as it positions itself. The right arm remains stationary initially, poised to assist." + ... "As the video progresses, the left arm grips the fabric, lifting it slightly off the cushion. The right arm" + ... "then moves in, its gripper adjusting to grasp the opposite side of the fabric. Both arms work in" + ... "coordination, lifting and holding the fabric between them. The fabric is manipulated with precision," + ... "showcasing the dexterity and control of the robotic arms. The camera remains static throughout, focusing" + ... "on the interaction between the robotic arms and the fabric, allowing viewers to observe the detailed" + ... "movements and coordination involved in the task." + ... ) + >>> negative_prompt = ( + ... "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + ... "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + ... "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky " + ... "movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " + ... "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " + ... "Overall, the video is of poor quality." + ... ) + >>> input_video = load_video( + ... "https://github.com/nvidia-cosmos/cosmos-transfer2.5/raw/refs/heads/main/assets/robot_example/robot_input.mp4" + ... ) + >>> num_frames = 93 + + >>> # Extract edge maps from the input video using Canny edge detection + >>> edge_maps = [ + ... cv2.Canny(cv2.cvtColor(np.array(frame.convert("RGB")), cv2.COLOR_RGB2BGR), 100, 200) + ... for frame in input_video[:num_frames] + ... ] + >>> edge_maps = np.stack(edge_maps)[None] # (T, H, W) -> (1, T, H, W) + >>> controls = torch.from_numpy(edge_maps).expand(3, -1, -1, -1) # (1, T, H, W) -> (3, T, H, W) + >>> controls = [Image.fromarray(x.numpy()) for x in controls.permute(1, 2, 3, 0)] + >>> export_to_video(controls, "edge_controlled_video_edge.mp4", fps=30) + + >>> # Transfer inference with controls. + >>> video = pipe( + ... controls=controls, + ... controls_conditioning_scale=1.0, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_frames=num_frames, + ... ).frames[0] + >>> export_to_video(video, "edge_controlled_video.mp4", fps=30) + ``` +""" + + +class Cosmos2_5_TransferPipeline(DiffusionPipeline): + r""" + Pipeline for Cosmos Transfer2.5, supporting auto-regressive inference. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`Qwen2_5_VLForConditionalGeneration`]): + Frozen text-encoder. Cosmos Transfer2.5 uses the [Qwen2.5 + VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) encoder. + tokenizer (`AutoTokenizer`): + Tokenizer associated with the Qwen2.5 VL encoder. + transformer ([`CosmosTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + controlnet ([`CosmosControlNetModel`]): + ControlNet used to condition generation on control inputs. + """ + + model_cpu_offload_seq = "text_encoder->transformer->controlnet->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + # We mark safety_checker as optional here to get around some test failures, but it is not really optional + _optional_components = ["safety_checker"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: AutoTokenizer, + transformer: CosmosTransformer3DModel, + vae: AutoencoderKLWan, + scheduler: UniPCMultistepScheduler, + controlnet: CosmosControlNetModel, + safety_checker: Optional[CosmosSafetyChecker] = None, + ): + super().__init__() + + if safety_checker is None: + safety_checker = CosmosSafetyChecker() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).float() + if getattr(self.vae.config, "latents_mean", None) is not None + else None + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).float() + if getattr(self.vae.config, "latents_std", None) is not None + else None + ) + self.latents_mean = latents_mean + self.latents_std = latents_std + + if self.latents_mean is None or self.latents_std is None: + raise ValueError("VAE configuration must define both `latents_mean` and `latents_std`.") + + def _get_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + + input_ids_batch = [] + + for sample_idx in range(len(prompt)): + conversations = [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "You are a helpful assistant who will provide prompts to an image generator.", + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt[sample_idx], + } + ], + }, + ] + input_ids = self.tokenizer.apply_chat_template( + conversations, + tokenize=True, + add_generation_prompt=False, + add_vision_id=False, + max_length=max_sequence_length, + truncation=True, + padding="max_length", + ) + input_ids = ( + input_ids["input_ids"] if not isinstance(input_ids, list) and "input_ids" in input_ids else input_ids + ) + input_ids = torch.LongTensor(input_ids) + input_ids_batch.append(input_ids) + + input_ids_batch = torch.stack(input_ids_batch, dim=0) + + outputs = self.text_encoder( + input_ids_batch.to(device), + output_hidden_states=True, + ) + hidden_states = outputs.hidden_states + + normalized_hidden_states = [] + for layer_idx in range(1, len(hidden_states)): + normalized_state = (hidden_states[layer_idx] - hidden_states[layer_idx].mean(dim=-1, keepdim=True)) / ( + hidden_states[layer_idx].std(dim=-1, keepdim=True) + 1e-8 + ) + normalized_hidden_states.append(normalized_state) + + prompt_embeds = torch.cat(normalized_hidden_states, dim=-1) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds + + # Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_prompt_embeds( + prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_prompt_embeds( + prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + # Modified from diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2VideoToWorldPipeline.prepare_latents and + # diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2TextToImagePipeline.prepare_latents + def prepare_latents( + self, + video: Optional[torch.Tensor], + batch_size: int, + num_channels_latents: int = 16, + height: int = 704, + width: int = 1280, + num_frames_in: int = 93, + num_frames_out: int = 93, + do_classifier_free_guidance: bool = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + num_cond_latent_frames: int = 0, + ) -> torch.Tensor: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + B = batch_size + C = num_channels_latents + T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1 + H = height // self.vae_scale_factor_spatial + W = width // self.vae_scale_factor_spatial + shape = (B, C, T, H, W) + + if latents is not None: + if latents.shape[1:] != shape[1:]: + raise ValueError(f"Unexpected `latents` shape, got {latents.shape}, expected {shape}.") + latents = latents.to(device=device, dtype=dtype) + else: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + if num_frames_in == 0: + cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device) + cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device) + + cond_latents = torch.zeros_like(latents) + + return ( + latents, + cond_latents, + cond_mask, + cond_indicator, + ) + else: + if video is None: + raise ValueError("`video` must be provided when `num_frames_in` is greater than 0.") + video = video.to(device=device, dtype=self.vae.dtype) + if isinstance(generator, list): + cond_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i]) + for i in range(batch_size) + ] + else: + cond_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + + cond_latents = torch.cat(cond_latents, dim=0).to(dtype) + + latents_mean = self.latents_mean.to(device=device, dtype=dtype) + latents_std = self.latents_std.to(device=device, dtype=dtype) + cond_latents = (cond_latents - latents_mean) / latents_std + + padding_shape = (B, 1, T, H, W) + ones_padding = latents.new_ones(padding_shape) + zeros_padding = latents.new_zeros(padding_shape) + + cond_indicator = latents.new_zeros(B, 1, latents.size(2), 1, 1) + cond_indicator[:, :, 0:num_cond_latent_frames, :, :] = 1.0 + cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding + + return ( + latents, + cond_latents, + cond_mask, + cond_indicator, + ) + + # Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + num_ar_conditional_frames=None, + num_ar_latent_conditional_frames=None, + num_frames_per_chunk=None, + num_frames=None, + conditional_frame_timestep=0.1, + ): + if width <= 0 or height <= 0 or height % 16 != 0 or width % 16 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 16 (& positive) but are {height} and {width}." + ) + + if num_frames is not None and num_frames <= 0: + raise ValueError(f"`num_frames` has to be a positive integer when provided but is {num_frames}.") + + if conditional_frame_timestep < 0 or conditional_frame_timestep > 1: + raise ValueError( + "`conditional_frame_timestep` has to be a float in the [0, 1] interval but is " + f"{conditional_frame_timestep}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if num_ar_latent_conditional_frames is not None and num_ar_conditional_frames is not None: + raise ValueError( + "Provide only one of `num_ar_conditional_frames` or `num_ar_latent_conditional_frames`, not both." + ) + if num_ar_latent_conditional_frames is None and num_ar_conditional_frames is None: + raise ValueError("Provide either `num_ar_conditional_frames` or `num_ar_latent_conditional_frames`.") + if num_ar_latent_conditional_frames is not None and num_ar_latent_conditional_frames < 0: + raise ValueError("`num_ar_latent_conditional_frames` must be >= 0.") + if num_ar_conditional_frames is not None and num_ar_conditional_frames < 0: + raise ValueError("`num_ar_conditional_frames` must be >= 0.") + + if num_ar_latent_conditional_frames is not None: + num_ar_conditional_frames = max( + 0, (num_ar_latent_conditional_frames - 1) * self.vae_scale_factor_temporal + 1 + ) + + min_chunk_len = self.vae_scale_factor_temporal + 1 + if num_frames_per_chunk < min_chunk_len: + logger.warning(f"{num_frames_per_chunk=} must be larger than {min_chunk_len=}, setting to min_chunk_len") + num_frames_per_chunk = min_chunk_len + + max_frames_by_rope = None + if getattr(self.transformer.config, "max_size", None) is not None: + max_frames_by_rope = max( + size // patch + for size, patch in zip(self.transformer.config.max_size, self.transformer.config.patch_size) + ) + if num_frames_per_chunk > max_frames_by_rope: + raise ValueError( + f"{num_frames_per_chunk=} is too large for RoPE setting ({max_frames_by_rope=}). " + "Please reduce `num_frames_per_chunk`." + ) + + if num_ar_conditional_frames >= num_frames_per_chunk: + raise ValueError( + f"{num_ar_conditional_frames=} must be smaller than {num_frames_per_chunk=} for chunked generation." + ) + + return num_frames_per_chunk + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + controls: PipelineImageInput | List[PipelineImageInput], + controls_conditioning_scale: Union[float, List[float]] = 1.0, + prompt: Union[str, List[str]] | None = None, + negative_prompt: Union[str, List[str]] = DEFAULT_NEGATIVE_PROMPT, + height: int = 704, + width: Optional[int] = None, + num_frames: Optional[int] = None, + num_frames_per_chunk: int = 93, + num_inference_steps: int = 36, + guidance_scale: float = 3.0, + num_videos_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + conditional_frame_timestep: float = 0.1, + num_ar_conditional_frames: Optional[int] = 1, + num_ar_latent_conditional_frames: Optional[int] = None, + ): + r""" + `controls` drive the conditioning through ControlNet. Controls are assumed to be pre-processed, e.g. edge maps + are pre-computed. + + Setting `num_frames` will restrict the total number of frames output, if not provided or assigned to None + (default) then the number of output frames will match the input `controls`. + + Auto-regressive inference is supported and thus a sliding window of `num_frames_per_chunk` frames are used per + denoising loop. In addition, when auto-regressive inference is performed, the previous + `num_ar_latent_conditional_frames` or `num_ar_conditional_frames` are used to condition the following denoising + inference loops. + + Args: + controls (`PipelineImageInput`, `List[PipelineImageInput]`): + Control image or video input used by the ControlNet. + controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`): + The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied. + height (`int`, defaults to `704`): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. If not provided, this will be determined based on the + aspect ratio of the input and the provided height. + num_frames (`int`, *optional*): + Number of output frames. Defaults to `None` to output the same number of frames as the input + `controls`. + num_inference_steps (`int`, defaults to `36`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `3.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs. Can be used to + tweak the same generation with different prompts. If not provided, a latents tensor is generated by + sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If + the prompt is shorter than this length, it will be padded. + num_ar_conditional_frames (`int`, *optional*, defaults to `1`): + Number of frames to condition on subsequent inference loops in auto-regressive inference, i.e. for the + second chunk and onwards. Only used if `num_ar_latent_conditional_frames` is `None`. + + This is only used when auto-regressive inference is performed, i.e. when the number of frames in + controls is > num_frames_per_chunk + num_ar_latent_conditional_frames (`int`, *optional*): + Number of latent frames to condition on subsequent inference loops in auto-regressive inference, i.e. + for the second chunk and onwards. Only used if `num_ar_conditional_frames` is `None`. + + This is only used when auto-regressive inference is performed, i.e. when the number of frames in + controls is > num_frames_per_chunk + Examples: + + Returns: + [`~CosmosPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + if self.safety_checker is None: + raise ValueError( + f"You have disabled the safety checker for {self.__class__}. This is in violation of the " + "[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). " + f"Please ensure that you are compliant with the license agreement." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + if width is None: + frame = controls[0] if isinstance(controls, list) else controls + if isinstance(frame, list): + frame = frame[0] + if isinstance(frame, (torch.Tensor, np.ndarray)): + if frame.ndim == 5: + frame = frame[0, 0] + elif frame.ndim == 4: + frame = frame[0] + + if isinstance(frame, PIL.Image.Image): + width = int((height + 16) * (frame.width / frame.height)) + else: + if frame.ndim != 3: + raise ValueError("`controls` must contain 3D frames in CHW format.") + width = int((height + 16) * (frame.shape[2] / frame.shape[1])) # NOTE: assuming C H W + + num_frames_per_chunk = self.check_inputs( + prompt, + height, + width, + prompt_embeds, + callback_on_step_end_tensor_inputs, + num_ar_conditional_frames, + num_ar_latent_conditional_frames, + num_frames_per_chunk, + num_frames, + conditional_frame_timestep, + ) + + if num_ar_latent_conditional_frames is not None: + num_cond_latent_frames = num_ar_latent_conditional_frames + num_ar_conditional_frames = max(0, (num_cond_latent_frames - 1) * self.vae_scale_factor_temporal + 1) + else: + num_cond_latent_frames = max(0, (num_ar_conditional_frames - 1) // self.vae_scale_factor_temporal + 1) + + self._guidance_scale = guidance_scale + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + if self.safety_checker is not None: + self.safety_checker.to(device) + if prompt is not None: + prompt_list = [prompt] if isinstance(prompt, str) else prompt + for p in prompt_list: + if not self.safety_checker.check_text_safety(p): + raise ValueError( + f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the " + f"prompt abides by the NVIDIA Open Model License Agreement." + ) + + # Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + vae_dtype = self.vae.dtype + transformer_dtype = self.transformer.dtype + + if getattr(self.transformer.config, "img_context_dim_in", None): + img_context = torch.zeros( + batch_size, + self.transformer.config.img_context_num_tokens, + self.transformer.config.img_context_dim_in, + device=prompt_embeds.device, + dtype=transformer_dtype, + ) + + if num_videos_per_prompt > 1: + img_context = img_context.repeat_interleave(num_videos_per_prompt, dim=0) + + encoder_hidden_states = (prompt_embeds, img_context) + neg_encoder_hidden_states = (negative_prompt_embeds, img_context) + else: + encoder_hidden_states = prompt_embeds + neg_encoder_hidden_states = negative_prompt_embeds + + control_video = self.video_processor.preprocess_video(controls, height, width) + if control_video.shape[0] != batch_size: + if control_video.shape[0] == 1: + control_video = control_video.repeat(batch_size, 1, 1, 1, 1) + else: + raise ValueError( + f"Expected controls batch size {batch_size} to match prompt batch size, but got {control_video.shape[0]}." + ) + + num_frames_out = control_video.shape[2] + if num_frames is not None: + num_frames_out = min(num_frames_out, num_frames) + + control_video = _maybe_pad_or_trim_video(control_video, num_frames_out) + + # chunk information + num_latent_frames_per_chunk = (num_frames_per_chunk - 1) // self.vae_scale_factor_temporal + 1 + chunk_stride = num_frames_per_chunk - num_ar_conditional_frames + chunk_idxs = [ + (start_idx, min(start_idx + num_frames_per_chunk, num_frames_out)) + for start_idx in range(0, num_frames_out - num_ar_conditional_frames, chunk_stride) + ] + + video_chunks = [] + latents_mean = self.latents_mean.to(dtype=vae_dtype, device=device) + latents_std = self.latents_std.to(dtype=vae_dtype, device=device) + + def decode_latents(latents): + latents = latents * latents_std + latents_mean + video = self.vae.decode(latents.to(dtype=self.vae.dtype, device=device), return_dict=False)[0] + return video + + latents_arg = latents + initial_num_cond_latent_frames = 0 + latent_chunks = [] + num_chunks = len(chunk_idxs) + total_steps = num_inference_steps * num_chunks + with self.progress_bar(total=total_steps) as progress_bar: + for chunk_idx, (start_idx, end_idx) in enumerate(chunk_idxs): + if chunk_idx == 0: + prev_output = torch.zeros((batch_size, num_frames_per_chunk, 3, height, width), dtype=vae_dtype) + prev_output = self.video_processor.preprocess_video(prev_output, height, width) + else: + prev_output = video_chunks[-1].clone() + if num_ar_conditional_frames > 0: + prev_output[:, :, :num_ar_conditional_frames] = prev_output[:, :, -num_ar_conditional_frames:] + prev_output[:, :, num_ar_conditional_frames:] = -1 # -1 == 0 in processed video space + else: + prev_output.fill_(-1) + + chunk_video = prev_output.to(device=device, dtype=vae_dtype) + chunk_video = _maybe_pad_or_trim_video(chunk_video, num_frames_per_chunk) + latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents( + video=chunk_video, + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=self.transformer.config.in_channels - 1, + height=height, + width=width, + num_frames_in=chunk_video.shape[2], + num_frames_out=num_frames_per_chunk, + do_classifier_free_guidance=self.do_classifier_free_guidance, + dtype=torch.float32, + device=device, + generator=generator, + num_cond_latent_frames=initial_num_cond_latent_frames + if chunk_idx == 0 + else num_cond_latent_frames, + latents=latents_arg, + ) + cond_mask = cond_mask.to(transformer_dtype) + cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep + padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype) + + chunk_control_video = control_video[:, :, start_idx:end_idx, ...].to( + device=device, dtype=self.vae.dtype + ) + chunk_control_video = _maybe_pad_or_trim_video(chunk_control_video, num_frames_per_chunk) + if isinstance(generator, list): + controls_latents = [ + retrieve_latents(self.vae.encode(chunk_control_video[i].unsqueeze(0)), generator=generator[i]) + for i in range(chunk_control_video.shape[0]) + ] + else: + controls_latents = [ + retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator) + for vid in chunk_control_video + ] + controls_latents = torch.cat(controls_latents, dim=0).to(transformer_dtype) + + controls_latents = (controls_latents - latents_mean) / latents_std + + # Denoising loop + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + + gt_velocity = (latents - cond_latent) * cond_mask + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t.cpu().item() + + # NOTE: assumes sigma(t) \in [0, 1] + sigma_t = ( + torch.tensor(self.scheduler.sigmas[i].item()) + .unsqueeze(0) + .to(device=device, dtype=transformer_dtype) + ) + + in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents + in_latents = in_latents.to(transformer_dtype) + in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t + control_output = self.controlnet( + controls_latents=controls_latents, + latents=in_latents, + timestep=in_timestep, + encoder_hidden_states=encoder_hidden_states, + condition_mask=cond_mask, + conditioning_scale=controls_conditioning_scale, + padding_mask=padding_mask, + return_dict=False, + ) + control_blocks = control_output[0] + + noise_pred = self.transformer( + hidden_states=in_latents, + timestep=in_timestep, + encoder_hidden_states=encoder_hidden_states, + block_controlnet_hidden_states=control_blocks, + condition_mask=cond_mask, + padding_mask=padding_mask, + return_dict=False, + )[0] + noise_pred = gt_velocity + noise_pred * (1 - cond_mask) + + if self.do_classifier_free_guidance: + control_output = self.controlnet( + controls_latents=controls_latents, + latents=in_latents, + timestep=in_timestep, + encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt + condition_mask=cond_mask, + conditioning_scale=controls_conditioning_scale, + padding_mask=padding_mask, + return_dict=False, + ) + control_blocks = control_output[0] + + noise_pred_neg = self.transformer( + hidden_states=in_latents, + timestep=in_timestep, + encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt + block_controlnet_hidden_states=control_blocks, + condition_mask=cond_mask, + padding_mask=padding_mask, + return_dict=False, + )[0] + # NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only + noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask) + noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg) + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == total_steps - 1 or ((i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + video_chunks.append(decode_latents(latents).detach().cpu()) + latent_chunks.append(latents.detach().cpu()) + + self._current_timestep = None + + if not output_type == "latent": + video_chunks = [ + chunk[:, :, num_ar_conditional_frames:, ...] if chunk_idx != 0 else chunk + for chunk_idx, chunk in enumerate(video_chunks) + ] + video = torch.cat(video_chunks, dim=2) + video = video[:, :, :num_frames_out, ...] + + assert self.safety_checker is not None + self.safety_checker.to(device) + video = self.video_processor.postprocess_video(video, output_type="np") + video = (video * 255).astype(np.uint8) + video_batch = [] + for vid in video: + vid = self.safety_checker.check_video_safety(vid) + if vid is None: + video_batch.append(np.zeros_like(video[0])) + else: + video_batch.append(vid) + video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1 + video = torch.from_numpy(video).permute(0, 4, 1, 2, 3) + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + latent_T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1 + latent_chunks = [ + chunk[:, :, num_cond_latent_frames:, ...] if chunk_idx != 0 else chunk + for chunk_idx, chunk in enumerate(latent_chunks) + ] + video = torch.cat(latent_chunks, dim=2) + video = video[:, :, :latent_T, ...] + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CosmosPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py new file mode 100644 index 0000000000000000000000000000000000000000..f24e19eea0d4bab2e404eb1135d0d83dd8fa0c53 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py @@ -0,0 +1,679 @@ +# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models import AutoencoderKLWan, CosmosTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import CosmosImagePipelineOutput + + +if is_cosmos_guardrail_available(): + from cosmos_guardrail import CosmosSafetyChecker +else: + + class CosmosSafetyChecker: + def __init__(self, *args, **kwargs): + raise ImportError( + "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`." + ) + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +DEFAULT_NEGATIVE_PROMPT = ( + "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, " + "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " + "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " + "Overall, the video is of poor quality." +) + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import Cosmos2TextToImagePipeline + + >>> # Available checkpoints: nvidia/Cosmos-Predict2-2B-Text2Image, nvidia/Cosmos-Predict2-14B-Text2Image + >>> model_id = "nvidia/Cosmos-Predict2-2B-Text2Image" + >>> pipe = Cosmos2TextToImagePipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess." + >>> negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality." + + >>> output = pipe( + ... prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1) + ... ).images[0] + >>> output.save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class Cosmos2TextToImagePipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using [Cosmos Predict2](https://github.com/nvidia-cosmos/cosmos-predict2). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. Cosmos uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-11b](https://huggingface.co/google-t5/t5-11b) variant. + tokenizer (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CosmosTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + # We mark safety_checker as optional here to get around some test failures, but it is not really optional + _optional_components = ["safety_checker"] + + def __init__( + self, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: CosmosTransformer3DModel, + vae: AutoencoderKLWan, + scheduler: FlowMatchEulerDiscreteScheduler, + safety_checker: CosmosSafetyChecker = None, + ): + super().__init__() + + if safety_checker is None: + safety_checker = CosmosSafetyChecker() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + self.sigma_max = 80.0 + self.sigma_min = 0.002 + self.sigma_data = 1.0 + self.final_sigmas_type = "sigma_min" + if self.scheduler is not None: + self.scheduler.register_to_config( + sigma_max=self.sigma_max, + sigma_min=self.sigma_min, + sigma_data=self.sigma_data, + final_sigmas_type=self.final_sigmas_type, + ) + + # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + return_length=True, + return_offsets_mapping=False, + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=prompt_attention_mask + ).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + lengths = prompt_attention_mask.sum(dim=1).cpu() + for i, length in enumerate(lengths): + prompt_embeds[i, length:] = 0 + + return prompt_embeds + + # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt with num_videos_per_prompt->num_images_per_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: 16, + height: int = 768, + width: int = 1360, + num_frames: int = 1, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) * self.scheduler.config.sigma_max + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents * self.scheduler.config.sigma_max + + # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + height: int = 768, + width: int = 1360, + num_inference_steps: int = 35, + guidance_scale: float = 7.0, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, defaults to `768`): + The height in pixels of the generated image. + width (`int`, defaults to `1360`): + The width in pixels of the generated image. + num_inference_steps (`int`, defaults to `35`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `7.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`CosmosImagePipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~CosmosImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`CosmosImagePipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if self.safety_checker is None: + raise ValueError( + f"You have disabled the safety checker for {self.__class__}. This is in violation of the " + "[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). " + f"Please ensure that you are compliant with the license agreement." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + num_frames = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs) + + self._guidance_scale = guidance_scale + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + if self.safety_checker is not None: + self.safety_checker.to(device) + if prompt is not None: + prompt_list = [prompt] if isinstance(prompt, str) else prompt + for p in prompt_list: + if not self.safety_checker.check_text_safety(p): + raise ValueError( + f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the " + f"prompt abides by the NVIDIA Open Model License Agreement." + ) + self.safety_checker.to("cpu") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare timesteps + sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + sigmas = torch.linspace(0, 1, num_inference_steps, dtype=sigmas_dtype) + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, device=device, sigmas=sigmas) + if self.scheduler.config.get("final_sigmas_type", "zero") == "sigma_min": + # Replace the last sigma (which is zero) with the minimum sigma value + self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2] + + # 5. Prepare latent variables + transformer_dtype = self.transformer.dtype + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + current_sigma = self.scheduler.sigmas[i] + + current_t = current_sigma / (current_sigma + 1) + c_in = 1 - current_t + c_skip = 1 - current_t + c_out = -current_t + timestep = current_t.expand(latents.shape[0]).to(transformer_dtype) # [B, 1, T, 1, 1] + + latent_model_input = latents * c_in + latent_model_input = latent_model_input.to(transformer_dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + padding_mask=padding_mask, + return_dict=False, + )[0] + noise_pred = (c_skip * latents + c_out * noise_pred.float()).to(transformer_dtype) + + if self.do_classifier_free_guidance: + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + padding_mask=padding_mask, + return_dict=False, + )[0] + noise_pred_uncond = (c_skip * latents + c_out * noise_pred_uncond.float()).to(transformer_dtype) + noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_uncond) + + noise_pred = (latents - noise_pred) / current_sigma + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std / self.scheduler.config.sigma_data + latents_mean + video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0] + + if self.safety_checker is not None: + self.safety_checker.to(device) + video = self.video_processor.postprocess_video(video, output_type="np") + video = (video * 255).astype(np.uint8) + video_batch = [] + for vid in video: + vid = self.safety_checker.check_video_safety(vid) + video_batch.append(vid) + video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1 + video = torch.from_numpy(video).permute(0, 4, 1, 2, 3) + video = self.video_processor.postprocess_video(video, output_type=output_type) + self.safety_checker.to("cpu") + else: + video = self.video_processor.postprocess_video(video, output_type=output_type) + image = [batch[0] for batch in video] + if isinstance(video, torch.Tensor): + image = torch.stack(image) + elif isinstance(video, np.ndarray): + image = np.stack(image) + else: + image = latents[:, :, 0] + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return CosmosImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py new file mode 100644 index 0000000000000000000000000000000000000000..bdb13af06637fec9a7b3b81a874b3c18f08d93dd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py @@ -0,0 +1,798 @@ +# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLWan, CosmosTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import CosmosPipelineOutput + + +if is_cosmos_guardrail_available(): + from cosmos_guardrail import CosmosSafetyChecker +else: + + class CosmosSafetyChecker: + def __init__(self, *args, **kwargs): + raise ImportError( + "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`." + ) + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +DEFAULT_NEGATIVE_PROMPT = ( + "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, " + "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " + "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " + "Overall, the video is of poor quality." +) + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import Cosmos2VideoToWorldPipeline + >>> from diffusers.utils import export_to_video, load_image + + >>> # Available checkpoints: nvidia/Cosmos-Predict2-2B-Video2World, nvidia/Cosmos-Predict2-14B-Video2World + >>> model_id = "nvidia/Cosmos-Predict2-2B-Video2World" + >>> pipe = Cosmos2VideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess." + >>> negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality." + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yellow-scrubber.png" + ... ) + + >>> video = pipe( + ... image=image, prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1) + ... ).frames[0] + >>> export_to_video(video, "output.mp4", fps=16) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class Cosmos2VideoToWorldPipeline(DiffusionPipeline): + r""" + Pipeline for video-to-world generation using [Cosmos Predict2](https://github.com/nvidia-cosmos/cosmos-predict2). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. Cosmos uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-11b](https://huggingface.co/google-t5/t5-11b) variant. + tokenizer (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CosmosTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + # We mark safety_checker as optional here to get around some test failures, but it is not really optional + _optional_components = ["safety_checker"] + + def __init__( + self, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: CosmosTransformer3DModel, + vae: AutoencoderKLWan, + scheduler: FlowMatchEulerDiscreteScheduler, + safety_checker: CosmosSafetyChecker = None, + ): + super().__init__() + + if safety_checker is None: + safety_checker = CosmosSafetyChecker() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + self.sigma_max = 80.0 + self.sigma_min = 0.002 + self.sigma_data = 1.0 + self.final_sigmas_type = "sigma_min" + if self.scheduler is not None: + self.scheduler.register_to_config( + sigma_max=self.sigma_max, + sigma_min=self.sigma_min, + sigma_data=self.sigma_data, + final_sigmas_type=self.final_sigmas_type, + ) + + # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + return_length=True, + return_offsets_mapping=False, + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=prompt_attention_mask + ).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + lengths = prompt_attention_mask.sum(dim=1).cpu() + for i, length in enumerate(lengths): + prompt_embeds[i, length:] = 0 + + return prompt_embeds + + # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, + video: torch.Tensor, + batch_size: int, + num_channels_latents: 16, + height: int = 704, + width: int = 1280, + num_frames: int = 93, + do_classifier_free_guidance: bool = True, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + num_cond_frames = video.size(2) + if num_cond_frames >= num_frames: + # Take the last `num_frames` frames for conditioning + num_cond_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + video = video[:, :, -num_frames:] + else: + num_cond_latent_frames = (num_cond_frames - 1) // self.vae_scale_factor_temporal + 1 + num_padding_frames = num_frames - num_cond_frames + last_frame = video[:, :, -1:] + padding = last_frame.repeat(1, 1, num_padding_frames, 1, 1) + video = torch.cat([video, padding], dim=2) + + if isinstance(generator, list): + init_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i]) + for i in range(batch_size) + ] + else: + init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype) + ) + init_latents = (init_latents - latents_mean) / latents_std * self.scheduler.config.sigma_data + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + latents = latents * self.scheduler.config.sigma_max + + padding_shape = (batch_size, 1, num_latent_frames, latent_height, latent_width) + ones_padding = latents.new_ones(padding_shape) + zeros_padding = latents.new_zeros(padding_shape) + + cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1) + cond_indicator[:, :, :num_cond_latent_frames] = 1.0 + cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding + + uncond_indicator = uncond_mask = None + if do_classifier_free_guidance: + uncond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1) + uncond_indicator[:, :, :num_cond_latent_frames] = 1.0 + uncond_mask = uncond_indicator * ones_padding + (1 - uncond_indicator) * zeros_padding + + return latents, init_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask + + # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput = None, + video: list[PipelineImageInput] = None, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + height: int = 704, + width: int = 1280, + num_frames: int = 93, + num_inference_steps: int = 35, + guidance_scale: float = 7.0, + fps: int = 16, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + sigma_conditioning: float = 0.0001, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*): + The image to be used as a conditioning input for the video generation. + video (`list[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*): + The video to be used as a conditioning input for the video generation. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, defaults to `704`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `93`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `35`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `7.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. + fps (`int`, defaults to `16`): + The frames per second of the generated video. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If + the prompt is shorter than this length, it will be padded. + sigma_conditioning (`float`, defaults to `0.0001`): + The sigma value used for scaling conditioning latents. Ideally, it should not be changed or should be + set to a small value close to zero. + + Examples: + + Returns: + [`~CosmosPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if self.safety_checker is None: + raise ValueError( + f"You have disabled the safety checker for {self.__class__}. This is in violation of the " + "[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). " + f"Please ensure that you are compliant with the license agreement." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs) + + self._guidance_scale = guidance_scale + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + if self.safety_checker is not None: + self.safety_checker.to(device) + if prompt is not None: + prompt_list = [prompt] if isinstance(prompt, str) else prompt + for p in prompt_list: + if not self.safety_checker.check_text_safety(p): + raise ValueError( + f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the " + f"prompt abides by the NVIDIA Open Model License Agreement." + ) + self.safety_checker.to("cpu") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare timesteps + sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + sigmas = torch.linspace(0, 1, num_inference_steps, dtype=sigmas_dtype) + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, device=device, sigmas=sigmas) + if self.scheduler.config.final_sigmas_type == "sigma_min": + # Replace the last sigma (which is zero) with the minimum sigma value + self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2] + + # 5. Prepare latent variables + vae_dtype = self.vae.dtype + transformer_dtype = self.transformer.dtype + + if image is not None: + video = self.video_processor.preprocess(image, height, width).unsqueeze(2) + else: + video = self.video_processor.preprocess_video(video, height, width) + video = video.to(device=device, dtype=vae_dtype) + + num_channels_latents = self.transformer.config.in_channels - 1 + latents, conditioning_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask = self.prepare_latents( + video, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + self.do_classifier_free_guidance, + torch.float32, + device, + generator, + latents, + ) + unconditioning_latents = None + + cond_mask = cond_mask.to(transformer_dtype) + if self.do_classifier_free_guidance: + uncond_mask = uncond_mask.to(transformer_dtype) + unconditioning_latents = conditioning_latents + + padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype) + sigma_conditioning = torch.tensor(sigma_conditioning, dtype=torch.float32, device=device) + t_conditioning = sigma_conditioning / (sigma_conditioning + 1) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + current_sigma = self.scheduler.sigmas[i] + + current_t = current_sigma / (current_sigma + 1) + c_in = 1 - current_t + c_skip = 1 - current_t + c_out = -current_t + timestep = current_t.view(1, 1, 1, 1, 1).expand( + latents.size(0), -1, latents.size(2), -1, -1 + ) # [B, 1, T, 1, 1] + + cond_latent = latents * c_in + cond_latent = cond_indicator * conditioning_latents + (1 - cond_indicator) * cond_latent + cond_latent = cond_latent.to(transformer_dtype) + cond_timestep = cond_indicator * t_conditioning + (1 - cond_indicator) * timestep + cond_timestep = cond_timestep.to(transformer_dtype) + + noise_pred = self.transformer( + hidden_states=cond_latent, + timestep=cond_timestep, + encoder_hidden_states=prompt_embeds, + fps=fps, + condition_mask=cond_mask, + padding_mask=padding_mask, + return_dict=False, + )[0] + noise_pred = (c_skip * latents + c_out * noise_pred.float()).to(transformer_dtype) + noise_pred = cond_indicator * conditioning_latents + (1 - cond_indicator) * noise_pred + + if self.do_classifier_free_guidance: + uncond_latent = latents * c_in + uncond_latent = uncond_indicator * unconditioning_latents + (1 - uncond_indicator) * uncond_latent + uncond_latent = uncond_latent.to(transformer_dtype) + uncond_timestep = uncond_indicator * t_conditioning + (1 - uncond_indicator) * timestep + uncond_timestep = uncond_timestep.to(transformer_dtype) + + noise_pred_uncond = self.transformer( + hidden_states=uncond_latent, + timestep=uncond_timestep, + encoder_hidden_states=negative_prompt_embeds, + fps=fps, + condition_mask=uncond_mask, + padding_mask=padding_mask, + return_dict=False, + )[0] + noise_pred_uncond = (c_skip * latents + c_out * noise_pred_uncond.float()).to(transformer_dtype) + noise_pred_uncond = ( + uncond_indicator * unconditioning_latents + (1 - uncond_indicator) * noise_pred_uncond + ) + noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_uncond) + + noise_pred = (latents - noise_pred) / current_sigma + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.scheduler.config.sigma_data + latents_mean + video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0] + + if self.safety_checker is not None: + self.safety_checker.to(device) + video = self.video_processor.postprocess_video(video, output_type="np") + video = (video * 255).astype(np.uint8) + video_batch = [] + for vid in video: + vid = self.safety_checker.check_video_safety(vid) + video_batch.append(vid) + video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1 + video = torch.from_numpy(video).permute(0, 4, 1, 2, 3) + video = self.video_processor.postprocess_video(video, output_type=output_type) + self.safety_checker.to("cpu") + else: + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CosmosPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py new file mode 100644 index 0000000000000000000000000000000000000000..e144d62d5933fa7a59c4ab63f330c20928e66053 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py @@ -0,0 +1,670 @@ +# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models import AutoencoderKLCosmos, CosmosTransformer3DModel +from ...schedulers import EDMEulerScheduler +from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import CosmosPipelineOutput + + +if is_cosmos_guardrail_available(): + from cosmos_guardrail import CosmosSafetyChecker +else: + + class CosmosSafetyChecker: + def __init__(self, *args, **kwargs): + raise ImportError( + "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`." + ) + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +DEFAULT_NEGATIVE_PROMPT = ( + "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, " + "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " + "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " + "Overall, the video is of poor quality." +) + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CosmosTextToWorldPipeline + >>> from diffusers.utils import export_to_video + + >>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Text2World" + >>> pipe = CosmosTextToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A sleek, humanoid robot stands in a vast warehouse filled with neatly stacked cardboard boxes on industrial shelves. The robot's metallic body gleams under the bright, even lighting, highlighting its futuristic design and intricate joints. A glowing blue light emanates from its chest, adding a touch of advanced technology. The background is dominated by rows of boxes, suggesting a highly organized storage system. The floor is lined with wooden pallets, enhancing the industrial setting. The camera remains static, capturing the robot's poised stance amidst the orderly environment, with a shallow depth of field that keeps the focus on the robot while subtly blurring the background for a cinematic effect." + + >>> output = pipe(prompt=prompt).frames[0] + >>> export_to_video(output, "output.mp4", fps=30) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class CosmosTextToWorldPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-world generation using [Cosmos Predict1](https://github.com/nvidia-cosmos/cosmos-predict1). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. Cosmos uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-11b](https://huggingface.co/google-t5/t5-11b) variant. + tokenizer (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CosmosTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLCosmos`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + # We mark safety_checker as optional here to get around some test failures, but it is not really optional + _optional_components = ["safety_checker"] + + def __init__( + self, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: CosmosTransformer3DModel, + vae: AutoencoderKLCosmos, + scheduler: EDMEulerScheduler, + safety_checker: CosmosSafetyChecker = None, + ): + super().__init__() + + if safety_checker is None: + safety_checker = CosmosSafetyChecker() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + ) + + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 8 + ) + self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + return_length=True, + return_offsets_mapping=False, + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=prompt_attention_mask + ).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + lengths = prompt_attention_mask.sum(dim=1).cpu() + for i, length in enumerate(lengths): + prompt_embeds[i, length:] = 0 + + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: 16, + height: int = 704, + width: int = 1280, + num_frames: int = 121, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) * self.scheduler.config.sigma_max + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents * self.scheduler.config.sigma_max + + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + height: int = 704, + width: int = 1280, + num_frames: int = 121, + num_inference_steps: int = 36, + guidance_scale: float = 7.0, + fps: int = 30, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, defaults to `720`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `121`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `36`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `7.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. + fps (`int`, defaults to `30`): + The frames per second of the generated video. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~CosmosPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if self.safety_checker is None: + raise ValueError( + f"You have disabled the safety checker for {self.__class__}. This is in violation of the " + "[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). " + f"Please ensure that you are compliant with the license agreement." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs) + + self._guidance_scale = guidance_scale + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + if self.safety_checker is not None: + self.safety_checker.to(device) + if prompt is not None: + prompt_list = [prompt] if isinstance(prompt, str) else prompt + for p in prompt_list: + if not self.safety_checker.check_text_safety(p): + raise ValueError( + f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the " + f"prompt abides by the NVIDIA Open Model License Agreement." + ) + self.safety_checker.to("cpu") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device) + + # 5. Prepare latent variables + transformer_dtype = self.transformer.dtype + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]).to(transformer_dtype) + + latent_model_input = latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = latent_model_input.to(transformer_dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + fps=fps, + padding_mask=padding_mask, + return_dict=False, + )[0] + + sample = latents + if self.do_classifier_free_guidance: + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + fps=fps, + padding_mask=padding_mask, + return_dict=False, + )[0] + noise_pred = torch.cat([noise_pred_uncond, noise_pred]) + sample = torch.cat([sample, sample]) + + # pred_original_sample (x0) + noise_pred = self.scheduler.step(noise_pred, t, sample, return_dict=False)[1] + self.scheduler._step_index -= 1 + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_cond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # pred_sample (eps) + latents = self.scheduler.step( + noise_pred, t, latents, return_dict=False, pred_original_sample=noise_pred + )[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + if self.vae.config.latents_mean is not None: + latents_mean, latents_std = self.vae.config.latents_mean, self.vae.config.latents_std + latents_mean = ( + torch.tensor(latents_mean) + .view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)] + .to(latents) + ) + latents_std = ( + torch.tensor(latents_std) + .view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)] + .to(latents) + ) + latents = latents * latents_std / self.scheduler.config.sigma_data + latents_mean + else: + latents = latents / self.scheduler.config.sigma_data + video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0] + + if self.safety_checker is not None: + self.safety_checker.to(device) + video = self.video_processor.postprocess_video(video, output_type="np") + video = (video * 255).astype(np.uint8) + video_batch = [] + for vid in video: + vid = self.safety_checker.check_video_safety(vid) + video_batch.append(vid) + video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1 + video = torch.from_numpy(video).permute(0, 4, 1, 2, 3) + video = self.video_processor.postprocess_video(video, output_type=output_type) + self.safety_checker.to("cpu") + else: + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CosmosPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py new file mode 100644 index 0000000000000000000000000000000000000000..377c3c05d2849b9f12796160cc46ddf5e35db66e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py @@ -0,0 +1,832 @@ +# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLCosmos, CosmosTransformer3DModel +from ...schedulers import EDMEulerScheduler +from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import CosmosPipelineOutput + + +if is_cosmos_guardrail_available(): + from cosmos_guardrail import CosmosSafetyChecker +else: + + class CosmosSafetyChecker: + def __init__(self, *args, **kwargs): + raise ImportError( + "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`." + ) + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +DEFAULT_NEGATIVE_PROMPT = ( + "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, " + "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " + "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " + "Overall, the video is of poor quality." +) + +EXAMPLE_DOC_STRING = """ + Examples: + Image conditioning: + + ```python + >>> import torch + >>> from diffusers import CosmosVideoToWorldPipeline + >>> from diffusers.utils import export_to_video, load_image + + >>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Video2World" + >>> pipe = CosmosVideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered clouds, suggesting a bright, sunny day." + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input.jpg" + ... ) + + >>> video = pipe(image=image, prompt=prompt).frames[0] + >>> export_to_video(video, "output.mp4", fps=30) + ``` + + Video conditioning: + + ```python + >>> import torch + >>> from diffusers import CosmosVideoToWorldPipeline + >>> from diffusers.utils import export_to_video, load_video + + >>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Video2World" + >>> pipe = CosmosVideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) + >>> pipe.transformer = torch.compile(pipe.transformer) + >>> pipe.to("cuda") + + >>> prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region." + >>> video = load_video( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4" + ... )[ + ... :21 + ... ] # This example uses only the first 21 frames + + >>> video = pipe(video=video, prompt=prompt).frames[0] + >>> export_to_video(video, "output.mp4", fps=30) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class CosmosVideoToWorldPipeline(DiffusionPipeline): + r""" + Pipeline for image-to-world and video-to-world generation using [Cosmos + Predict-1](https://github.com/nvidia-cosmos/cosmos-predict1). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. Cosmos uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-11b](https://huggingface.co/google-t5/t5-11b) variant. + tokenizer (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CosmosTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLCosmos`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + # We mark safety_checker as optional here to get around some test failures, but it is not really optional + _optional_components = ["safety_checker"] + + def __init__( + self, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: CosmosTransformer3DModel, + vae: AutoencoderKLCosmos, + scheduler: EDMEulerScheduler, + safety_checker: CosmosSafetyChecker = None, + ): + super().__init__() + + if safety_checker is None: + safety_checker = CosmosSafetyChecker() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + ) + + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 8 + ) + self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + return_length=True, + return_offsets_mapping=False, + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=prompt_attention_mask + ).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + lengths = prompt_attention_mask.sum(dim=1).cpu() + for i, length in enumerate(lengths): + prompt_embeds[i, length:] = 0 + + return prompt_embeds + + # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, + video: torch.Tensor, + batch_size: int, + num_channels_latents: 16, + height: int = 704, + width: int = 1280, + num_frames: int = 121, + do_classifier_free_guidance: bool = True, + input_frames_guidance: bool = False, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + num_cond_frames = video.size(2) + if num_cond_frames >= num_frames: + # Take the last `num_frames` frames for conditioning + num_cond_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + video = video[:, :, -num_frames:] + else: + num_cond_latent_frames = (num_cond_frames - 1) // self.vae_scale_factor_temporal + 1 + num_padding_frames = num_frames - num_cond_frames + padding = video.new_zeros(video.size(0), video.size(1), num_padding_frames, video.size(3), video.size(4)) + video = torch.cat([video, padding], dim=2) + + if isinstance(generator, list): + init_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i]) + for i in range(batch_size) + ] + else: + init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + + if self.vae.config.latents_mean is not None: + latents_mean, latents_std = self.vae.config.latents_mean, self.vae.config.latents_std + latents_mean = ( + torch.tensor(latents_mean) + .view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : init_latents.size(2)] + .to(init_latents) + ) + latents_std = ( + torch.tensor(latents_std) + .view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : init_latents.size(2)] + .to(init_latents) + ) + init_latents = (init_latents - latents_mean) * self.scheduler.config.sigma_data / latents_std + else: + init_latents = init_latents * self.scheduler.config.sigma_data + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + latents = latents * self.scheduler.config.sigma_max + + padding_shape = (batch_size, 1, num_latent_frames, latent_height, latent_width) + ones_padding = latents.new_ones(padding_shape) + zeros_padding = latents.new_zeros(padding_shape) + + cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1) + cond_indicator[:, :, :num_cond_latent_frames] = 1.0 + cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding + + uncond_indicator = uncond_mask = None + if do_classifier_free_guidance: + uncond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1) + uncond_indicator[:, :, :num_cond_latent_frames] = 1.0 + uncond_mask = zeros_padding + if not input_frames_guidance: + uncond_mask = uncond_indicator * ones_padding + (1 - uncond_indicator) * zeros_padding + + return latents, init_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask + + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + image=None, + video=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if image is None and video is None: + raise ValueError("Either `image` or `video` has to be provided.") + if image is not None and video is not None: + raise ValueError("Only one of `image` or `video` has to be provided.") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput = None, + video: list[PipelineImageInput] = None, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + height: int = 704, + width: int = 1280, + num_frames: int = 121, + num_inference_steps: int = 36, + guidance_scale: float = 7.0, + input_frames_guidance: bool = False, + augment_sigma: float = 0.001, + fps: int = 30, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, defaults to `720`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `121`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `36`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `7.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. + fps (`int`, defaults to `30`): + The frames per second of the generated video. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~CosmosPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if self.safety_checker is None: + raise ValueError( + f"You have disabled the safety checker for {self.__class__}. This is in violation of the " + "[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). " + f"Please ensure that you are compliant with the license agreement." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs, image, video) + + self._guidance_scale = guidance_scale + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + if self.safety_checker is not None: + self.safety_checker.to(device) + if prompt is not None: + prompt_list = [prompt] if isinstance(prompt, str) else prompt + for p in prompt_list: + if not self.safety_checker.check_text_safety(p): + raise ValueError( + f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the " + f"prompt abides by the NVIDIA Open Model License Agreement." + ) + self.safety_checker.to("cpu") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device) + + # 5. Prepare latent variables + vae_dtype = self.vae.dtype + transformer_dtype = self.transformer.dtype + + if image is not None: + video = self.video_processor.preprocess(image, height, width).unsqueeze(2) + else: + video = self.video_processor.preprocess_video(video, height, width) + video = video.to(device=device, dtype=vae_dtype) + + num_channels_latents = self.transformer.config.in_channels - 1 + latents, conditioning_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask = self.prepare_latents( + video, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + self.do_classifier_free_guidance, + input_frames_guidance, + torch.float32, + device, + generator, + latents, + ) + cond_mask = cond_mask.to(transformer_dtype) + if self.do_classifier_free_guidance: + uncond_mask = uncond_mask.to(transformer_dtype) + + augment_sigma = torch.tensor([augment_sigma], device=device, dtype=torch.float32) + padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]).to(transformer_dtype) + + current_sigma = self.scheduler.sigmas[i] + is_augment_sigma_greater = augment_sigma >= current_sigma + + c_in_augment = self.scheduler._get_conditioning_c_in(augment_sigma) + c_in_original = self.scheduler._get_conditioning_c_in(current_sigma) + + current_cond_indicator = cond_indicator * 0 if is_augment_sigma_greater else cond_indicator + cond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32) + cond_latent = conditioning_latents + cond_noise * augment_sigma[:, None, None, None, None] + cond_latent = cond_latent * c_in_augment / c_in_original + cond_latent = current_cond_indicator * cond_latent + (1 - current_cond_indicator) * latents + cond_latent = self.scheduler.scale_model_input(cond_latent, t) + cond_latent = cond_latent.to(transformer_dtype) + + noise_pred = self.transformer( + hidden_states=cond_latent, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + fps=fps, + condition_mask=cond_mask, + padding_mask=padding_mask, + return_dict=False, + )[0] + + sample = latents + if self.do_classifier_free_guidance: + current_uncond_indicator = uncond_indicator * 0 if is_augment_sigma_greater else uncond_indicator + uncond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32) + uncond_latent = conditioning_latents + uncond_noise * augment_sigma[:, None, None, None, None] + uncond_latent = uncond_latent * c_in_augment / c_in_original + uncond_latent = current_uncond_indicator * uncond_latent + (1 - current_uncond_indicator) * latents + uncond_latent = self.scheduler.scale_model_input(uncond_latent, t) + uncond_latent = uncond_latent.to(transformer_dtype) + + noise_pred_uncond = self.transformer( + hidden_states=uncond_latent, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + fps=fps, + condition_mask=uncond_mask, + padding_mask=padding_mask, + return_dict=False, + )[0] + noise_pred = torch.cat([noise_pred_uncond, noise_pred]) + sample = torch.cat([sample, sample]) + + # pred_original_sample (x0) + noise_pred = self.scheduler.step(noise_pred, t, sample, return_dict=False)[1] + self.scheduler._step_index -= 1 + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0) + noise_pred_uncond = ( + current_uncond_indicator * conditioning_latents + + (1 - current_uncond_indicator) * noise_pred_uncond + ) + noise_pred_cond = ( + current_cond_indicator * conditioning_latents + (1 - current_cond_indicator) * noise_pred_cond + ) + noise_pred = noise_pred_cond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = ( + current_cond_indicator * conditioning_latents + (1 - current_cond_indicator) * noise_pred + ) + + # pred_sample (eps) + latents = self.scheduler.step( + noise_pred, t, latents, return_dict=False, pred_original_sample=noise_pred + )[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + if self.vae.config.latents_mean is not None: + latents_mean, latents_std = self.vae.config.latents_mean, self.vae.config.latents_std + latents_mean = ( + torch.tensor(latents_mean) + .view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)] + .to(latents) + ) + latents_std = ( + torch.tensor(latents_std) + .view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)] + .to(latents) + ) + latents = latents * latents_std / self.scheduler.config.sigma_data + latents_mean + else: + latents = latents / self.scheduler.config.sigma_data + video = self.vae.decode(latents.to(vae_dtype), return_dict=False)[0] + + if self.safety_checker is not None: + self.safety_checker.to(device) + video = self.video_processor.postprocess_video(video, output_type="np") + video = (video * 255).astype(np.uint8) + video_batch = [] + for vid in video: + vid = self.safety_checker.check_video_safety(vid) + video_batch.append(vid) + video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1 + video = torch.from_numpy(video).permute(0, 4, 1, 2, 3) + video = self.video_processor.postprocess_video(video, output_type=output_type) + self.safety_checker.to("cpu") + else: + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CosmosPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..1ded292f8dfbdd7e938ebdcea1bc9fd03b545b84 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/cosmos/pipeline_output.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image +import torch + +from diffusers.utils import BaseOutput, get_logger + + +logger = get_logger(__name__) + + +@dataclass +class CosmosPipelineOutput(BaseOutput): + r""" + Output class for Cosmos any-to-world/video pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor + + +@dataclass +class CosmosImagePipelineOutput(BaseOutput): + """ + Output class for Cosmos any-to-image pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image] | np.ndarray diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/dance_diffusion/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/dance_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0d3e466dfa65b2e9890451607959ed45d092cae7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/dance_diffusion/__init__.py @@ -0,0 +1,18 @@ +from typing import TYPE_CHECKING + +from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = {"pipeline_dance_diffusion": ["DanceDiffusionPipeline"]} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_dance_diffusion import DanceDiffusionPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/dance_diffusion/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/dance_diffusion/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7f1d1eca642fbd40c56e01069adb22164b1e86e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/dance_diffusion/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/dance_diffusion/__pycache__/pipeline_dance_diffusion.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/dance_diffusion/__pycache__/pipeline_dance_diffusion.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..610f27c01f67c45f20d89308e1a85e026e4777d9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/dance_diffusion/__pycache__/pipeline_dance_diffusion.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..eb8f8106061d4332881bec919ea1137516b7cf66 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py @@ -0,0 +1,167 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from ...models import UNet1DModel +from ...schedulers import SchedulerMixin +from ...utils import is_torch_xla_available, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import AudioPipelineOutput, DeprecatedPipelineMixin, DiffusionPipeline + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class DanceDiffusionPipeline(DeprecatedPipelineMixin, DiffusionPipeline): + r""" + Pipeline for audio generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + unet ([`UNet1DModel`]): + A `UNet1DModel` to denoise the encoded audio. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of + [`IPNDMScheduler`]. + """ + + _last_supported_version = "0.33.1" + model_cpu_offload_seq = "unet" + + def __init__(self, unet: UNet1DModel, scheduler: SchedulerMixin): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 100, + generator: torch.Generator | list[torch.Generator] | None = None, + audio_length_in_s: float | None = None, + return_dict: bool = True, + ) -> AudioPipelineOutput | tuple: + r""" + The call function to the pipeline for generation. + + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of audio samples to generate. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher-quality audio sample at + the expense of slower inference. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`): + The length of the generated audio sample in seconds. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple. + + Example: + + ```py + from diffusers import DiffusionPipeline + from scipy.io.wavfile import write + + model_id = "harmonai/maestro-150k" + pipe = DiffusionPipeline.from_pretrained(model_id) + pipe = pipe.to("cuda") + + audios = pipe(audio_length_in_s=4.0).audios + + # To save locally + for i, audio in enumerate(audios): + write(f"maestro_test_{i}.wav", pipe.unet.sample_rate, audio.transpose()) + + # To display in google colab + import IPython.display as ipd + + for audio in audios: + display(ipd.Audio(audio, rate=pipe.unet.sample_rate)) + ``` + + Returns: + [`~pipelines.AudioPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.AudioPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated audio. + """ + + if audio_length_in_s is None: + audio_length_in_s = self.unet.config.sample_size / self.unet.config.sample_rate + + sample_size = audio_length_in_s * self.unet.config.sample_rate + + down_scale_factor = 2 ** len(self.unet.up_blocks) + if sample_size < 3 * down_scale_factor: + raise ValueError( + f"{audio_length_in_s} is too small. Make sure it's bigger or equal to" + f" {3 * down_scale_factor / self.unet.config.sample_rate}." + ) + + original_sample_size = int(sample_size) + if sample_size % down_scale_factor != 0: + sample_size = ( + (audio_length_in_s * self.unet.config.sample_rate) // down_scale_factor + 1 + ) * down_scale_factor + logger.info( + f"{audio_length_in_s} is increased to {sample_size / self.unet.config.sample_rate} so that it can be handled" + f" by the model. It will be cut to {original_sample_size / self.unet.config.sample_rate} after the denoising" + " process." + ) + sample_size = int(sample_size) + + dtype = next(self.unet.parameters()).dtype + shape = (batch_size, self.unet.config.in_channels, sample_size) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + audio = randn_tensor(shape, generator=generator, device=self._execution_device, dtype=dtype) + + # set step values + self.scheduler.set_timesteps(num_inference_steps, device=audio.device) + self.scheduler.timesteps = self.scheduler.timesteps.to(dtype) + + for t in self.progress_bar(self.scheduler.timesteps): + # 1. predict noise model_output + model_output = self.unet(audio, t).sample + + # 2. compute previous audio sample: x_t -> t_t-1 + audio = self.scheduler.step(model_output, t, audio).prev_sample + + if XLA_AVAILABLE: + xm.mark_step() + + audio = audio.clamp(-1, 1).float().cpu().numpy() + + audio = audio[:, :, :original_sample_size] + + if not return_dict: + return (audio,) + + return AudioPipelineOutput(audios=audio) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ddim/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ddim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d9eede47c897370a23c47c05291690881c987025 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ddim/__init__.py @@ -0,0 +1,18 @@ +from typing import TYPE_CHECKING + +from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = {"pipeline_ddim": ["DDIMPipeline"]} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_ddim import DDIMPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ddim/pipeline_ddim.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ddim/pipeline_ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..dc92b34a7565c0c3bf8f17714867d5278c119042 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ddim/pipeline_ddim.py @@ -0,0 +1,165 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ...models import UNet2DModel +from ...schedulers import DDIMScheduler +from ...utils import is_torch_xla_available +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +class DDIMPipeline(DiffusionPipeline): + r""" + Pipeline for image generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of + [`DDPMScheduler`], or [`DDIMScheduler`]. + """ + + model_cpu_offload_seq = "unet" + + def __init__(self, unet: UNet2DModel, scheduler: DDIMScheduler): + super().__init__() + + # make sure scheduler can always be converted to DDIM + scheduler = DDIMScheduler.from_config(scheduler.config) + + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + eta: float = 0.0, + num_inference_steps: int = 50, + use_clipped_model_output: bool | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + ) -> ImagePipelineOutput | tuple: + r""" + The call function to the pipeline for generation. + + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. A value of `0` + corresponds to DDIM and `1` corresponds to DDPM. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + use_clipped_model_output (`bool`, *optional*, defaults to `None`): + If `True` or `False`, see documentation for [`DDIMScheduler.step`]. If `None`, nothing is passed + downstream to the scheduler (use `None` for schedulers which don't support this argument). + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Example: + + ```py + >>> from diffusers import DDIMPipeline + >>> import PIL.Image + >>> import numpy as np + + >>> # load model and scheduler + >>> pipe = DDIMPipeline.from_pretrained("fusing/ddim-lsun-bedroom") + + >>> # run pipeline in inference (sample random noise and denoise) + >>> image = pipe(eta=0.0, num_inference_steps=50) + + >>> # process image to PIL + >>> image_processed = image.cpu().permute(0, 2, 3, 1) + >>> image_processed = (image_processed + 1.0) * 127.5 + >>> image_processed = image_processed.numpy().astype(np.uint8) + >>> image_pil = PIL.Image.fromarray(image_processed[0]) + + >>> # save image + >>> image_pil.save("test.png") + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + + # Sample gaussian noise to begin loop + if isinstance(self.unet.config.sample_size, int): + image_shape = ( + batch_size, + self.unet.config.in_channels, + self.unet.config.sample_size, + self.unet.config.sample_size, + ) + else: + image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + image = randn_tensor(image_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype) + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.progress_bar(self.scheduler.timesteps): + # 1. predict noise model_output + model_output = self.unet(image, t).sample + + # 2. predict previous mean of image x_t-1 and add variance depending on eta + # eta corresponds to η in paper and should be between [0, 1] + # do x_t -> x_t-1 + image = self.scheduler.step( + model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator + ).prev_sample + + if XLA_AVAILABLE: + xm.mark_step() + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ddpm/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ddpm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb41dd1dcf642c791f3d7b0d985efcaf3e4a2c22 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ddpm/__init__.py @@ -0,0 +1,22 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + _LazyModule, +) + + +_import_structure = {"pipeline_ddpm": ["DDPMPipeline"]} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_ddpm import DDPMPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ddpm/pipeline_ddpm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ddpm/pipeline_ddpm.py new file mode 100644 index 0000000000000000000000000000000000000000..6d4796cbea1f3fe00f4a3417bcff96ea09e104a3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -0,0 +1,139 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from ...models import UNet2DModel +from ...schedulers import DDPMScheduler +from ...utils import is_torch_xla_available +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +class DDPMPipeline(DiffusionPipeline): + r""" + Pipeline for image generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of + [`DDPMScheduler`], or [`DDIMScheduler`]. + """ + + model_cpu_offload_seq = "unet" + + def __init__(self, unet: UNet2DModel, scheduler: DDPMScheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + num_inference_steps: int = 1000, + output_type: str | None = "pil", + return_dict: bool = True, + ) -> ImagePipelineOutput | tuple: + r""" + The call function to the pipeline for generation. + + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + num_inference_steps (`int`, *optional*, defaults to 1000): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Example: + + ```py + >>> from diffusers import DDPMPipeline + + >>> # load model and scheduler + >>> pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256") + + >>> # run pipeline in inference (sample random noise and denoise) + >>> image = pipe().images[0] + + >>> # save image + >>> image.save("ddpm_generated_image.png") + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + # Sample gaussian noise to begin loop + if isinstance(self.unet.config.sample_size, int): + image_shape = ( + batch_size, + self.unet.config.in_channels, + self.unet.config.sample_size, + self.unet.config.sample_size, + ) + else: + image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size) + + if self.device.type == "mps": + # randn does not work reproducibly on mps + image = randn_tensor(image_shape, generator=generator, dtype=self.unet.dtype) + image = image.to(self.device) + else: + image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype) + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.progress_bar(self.scheduler.timesteps): + # 1. predict noise model_output + model_output = self.unet(image, t).sample + + # 2. compute previous image: x_t -> x_t-1 + image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample + + if XLA_AVAILABLE: + xm.mark_step() + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..79aab1fb186a857dd0a3353c4b5905b4595b5b7b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__init__.py @@ -0,0 +1,85 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = { + "timesteps": [ + "fast27_timesteps", + "smart100_timesteps", + "smart185_timesteps", + "smart27_timesteps", + "smart50_timesteps", + "super100_timesteps", + "super27_timesteps", + "super40_timesteps", + ] +} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_if"] = ["IFPipeline"] + _import_structure["pipeline_if_img2img"] = ["IFImg2ImgPipeline"] + _import_structure["pipeline_if_img2img_superresolution"] = ["IFImg2ImgSuperResolutionPipeline"] + _import_structure["pipeline_if_inpainting"] = ["IFInpaintingPipeline"] + _import_structure["pipeline_if_inpainting_superresolution"] = ["IFInpaintingSuperResolutionPipeline"] + _import_structure["pipeline_if_superresolution"] = ["IFSuperResolutionPipeline"] + _import_structure["pipeline_output"] = ["IFPipelineOutput"] + _import_structure["safety_checker"] = ["IFSafetyChecker"] + _import_structure["watermark"] = ["IFWatermarker"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_if import IFPipeline + from .pipeline_if_img2img import IFImg2ImgPipeline + from .pipeline_if_img2img_superresolution import IFImg2ImgSuperResolutionPipeline + from .pipeline_if_inpainting import IFInpaintingPipeline + from .pipeline_if_inpainting_superresolution import IFInpaintingSuperResolutionPipeline + from .pipeline_if_superresolution import IFSuperResolutionPipeline + from .pipeline_output import IFPipelineOutput + from .safety_checker import IFSafetyChecker + from .timesteps import ( + fast27_timesteps, + smart27_timesteps, + smart50_timesteps, + smart100_timesteps, + smart185_timesteps, + super27_timesteps, + super40_timesteps, + super100_timesteps, + ) + from .watermark import IFWatermarker + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28d6c06d583aaaddbbe7a88c313add44366c5295 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d6820c6f2f46f29ad7e0f345893ab567b121e69 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e66ee62b9c42d363244d06447bcd99d098ac06e5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img_superresolution.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img_superresolution.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bcaf31f084499248eec3dd3dd2d647b878c642ae Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_img2img_superresolution.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65990e1c3de95fa87078102084e4868d43220074 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting_superresolution.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting_superresolution.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc7208633d5d74e9842b35b645575fbf91ed8686 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_inpainting_superresolution.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_superresolution.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_superresolution.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ccc773de4bd1581e40bcc6398e97e871a497d24 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_if_superresolution.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_output.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_output.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e94dc8f2733501077664a5c1510e983275c1860e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/pipeline_output.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/safety_checker.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/safety_checker.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..100305388b5c1d49878a73c1eadf4749e0a250da Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/safety_checker.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/timesteps.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/timesteps.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc9ce117522148eb9cf8bf90f569183a75fea654 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/timesteps.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/watermark.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/watermark.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbf348ce6289664cb80ca3b929b1ebb459046aad Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/__pycache__/watermark.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/pipeline_if.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/pipeline_if.py new file mode 100644 index 0000000000000000000000000000000000000000..b8c70fc6528cfbea087ab1e01fcc0b0218149cb7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/pipeline_if.py @@ -0,0 +1,786 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable + +import torch +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from ...loaders import StableDiffusionLoraLoaderMixin +from ...models import UNet2DConditionModel +from ...schedulers import DDPMScheduler +from ...utils import ( + BACKENDS_MAPPING, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import IFPipelineOutput +from .safety_checker import IFSafetyChecker +from .watermark import IFWatermarker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import IFPipeline, IFSuperResolutionPipeline, DiffusionPipeline + >>> from diffusers.utils import pt_to_pil + >>> import torch + + >>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"' + >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + + >>> image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained( + ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 + ... ) + >>> super_res_1_pipe.enable_model_cpu_offload() + + >>> image = super_res_1_pipe( + ... image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt" + ... ).images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> safety_modules = { + ... "feature_extractor": pipe.feature_extractor, + ... "safety_checker": pipe.safety_checker, + ... "watermarker": pipe.watermarker, + ... } + >>> super_res_2_pipe = DiffusionPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16 + ... ) + >>> super_res_2_pipe.enable_model_cpu_offload() + + >>> image = super_res_2_pipe( + ... prompt=prompt, + ... image=image, + ... ).images + >>> image[0].save("./if_stage_II.png") + ``` +""" + + +class IFPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet2DConditionModel + scheduler: DDPMScheduler + + feature_extractor: CLIPImageProcessor | None + safety_checker: IFSafetyChecker | None + + watermarker: IFWatermarker | None + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] + model_cpu_offload_seq = "text_encoder->unet" + _exclude_from_cpu_offload = ["watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + safety_checker: IFSafetyChecker | None, + feature_extractor: CLIPImageProcessor | None, + watermarker: IFWatermarker | None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the IF license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + watermarker=watermarker, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + @torch.no_grad() + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + device: torch.device | None = None, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + clean_caption (bool, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_intermediate_images(self, batch_size, num_channels, height, width, dtype, device, generator): + shape = (batch_size, num_channels, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + intermediate_images = intermediate_images * self.scheduler.init_noise_sigma + return intermediate_images + + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + num_inference_steps: int = 100, + timesteps: list[int] = None, + guidance_scale: float = 7.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + height: int | None = None, + width: int | None = None, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + clean_caption: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 2. Define call parameters + height = height or self.unet.config.sample_size + width = width or self.unet.config.sample_size + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(0) + + # 5. Prepare intermediate images + intermediate_images = self.prepare_intermediate_images( + batch_size * num_images_per_prompt, + self.unet.config.in_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = ( + torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images + ) + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + if XLA_AVAILABLE: + xm.mark_step() + + image = intermediate_images + + if output_type == "pil": + # 8. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 9. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + + # 11. Apply watermark + if self.watermarker is not None: + image = self.watermarker.apply_watermark(image, self.unet.config.sample_size) + elif output_type == "pt": + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + else: + # 8. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 9. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, nsfw_detected, watermark_detected) + + return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..3dadc63f4952df555a01260a6c16f9eebf8ec7ba --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py @@ -0,0 +1,910 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from ...loaders import StableDiffusionLoraLoaderMixin +from ...models import UNet2DConditionModel +from ...schedulers import DDPMScheduler +from ...utils import ( + BACKENDS_MAPPING, + PIL_INTERPOLATION, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import IFPipelineOutput +from .safety_checker import IFSafetyChecker +from .watermark import IFWatermarker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image: + w, h = images.size + + coef = w / h + + w, h = img_size, img_size + + if coef >= 1: + w = int(round(img_size / 8 * coef) * 8) + else: + h = int(round(img_size / 8 / coef) * 8) + + images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None) + + return images + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, DiffusionPipeline + >>> from diffusers.utils import pt_to_pil + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from io import BytesIO + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> response = requests.get(url) + >>> original_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> original_image = original_image.resize((768, 512)) + + >>> pipe = IFImg2ImgPipeline.from_pretrained( + ... "DeepFloyd/IF-I-XL-v1.0", + ... variant="fp16", + ... torch_dtype=torch.float16, + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A fantasy landscape in style minecraft" + >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + + >>> image = pipe( + ... image=original_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... output_type="pt", + ... ).images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> super_res_1_pipe = IFImg2ImgSuperResolutionPipeline.from_pretrained( + ... "DeepFloyd/IF-II-L-v1.0", + ... text_encoder=None, + ... variant="fp16", + ... torch_dtype=torch.float16, + ... ) + >>> super_res_1_pipe.enable_model_cpu_offload() + + >>> image = super_res_1_pipe( + ... image=image, + ... original_image=original_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... ).images + >>> image[0].save("./if_stage_II.png") + ``` +""" + + +class IFImg2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet2DConditionModel + scheduler: DDPMScheduler + + feature_extractor: CLIPImageProcessor | None + safety_checker: IFSafetyChecker | None + + watermarker: IFWatermarker | None + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] + model_cpu_offload_seq = "text_encoder->unet" + _exclude_from_cpu_offload = ["watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + safety_checker: IFSafetyChecker | None, + feature_extractor: CLIPImageProcessor | None, + watermarker: IFWatermarker | None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the IF license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + watermarker=watermarker, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + @torch.no_grad() + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + device: torch.device | None = None, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + clean_caption (bool, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + batch_size, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if isinstance(image, list): + check_image_type = image[0] + else: + check_image_type = image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or list[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(image, list): + image_batch_size = len(image) + elif isinstance(image, torch.Tensor): + image_batch_size = image.shape[0] + elif isinstance(image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(image, np.ndarray): + image_batch_size = image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def preprocess_image(self, image: PIL.Image.Image) -> torch.Tensor: + if not isinstance(image, list): + image = [image] + + def numpy_to_pt(images): + if images.ndim == 3: + images = images[..., None] + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + if isinstance(image[0], PIL.Image.Image): + new_image = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = resize(image_, self.unet.config.sample_size) + image_ = np.array(image_) + image_ = image_.astype(np.float32) + image_ = image_ / 127.5 - 1 + new_image.append(image_) + + image = new_image + + image = np.stack(image, axis=0) # to np + image = numpy_to_pt(image) # to pt + + elif isinstance(image[0], np.ndarray): + image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) + image = numpy_to_pt(image) + + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def prepare_intermediate_images( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None + ): + _, channels, height, width = image.shape + + batch_size = batch_size * num_images_per_prompt + + shape = (batch_size, channels, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + image = self.scheduler.add_noise(image, noise, timestep) + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + image: PIL.Image.Image + | torch.Tensor + | np.ndarray + | list[PIL.Image.Image] + | list[torch.Tensor] + | list[np.ndarray] = None, + strength: float = 0.7, + num_inference_steps: int = 80, + timesteps: list[int] = None, + guidance_scale: float = 10.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + clean_caption: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.Tensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + strength (`float`, *optional*, defaults to 0.7): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 80): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 10.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + self.check_inputs( + prompt, image, batch_size, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + dtype = prompt_embeds.dtype + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) + + # 5. Prepare intermediate images + image = self.preprocess_image(image) + image = image.to(device=device, dtype=dtype) + + noise_timestep = timesteps[0:1] + noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt) + + intermediate_images = self.prepare_intermediate_images( + image, noise_timestep, batch_size, num_images_per_prompt, dtype, device, generator + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = ( + torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images + ) + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + if XLA_AVAILABLE: + xm.mark_step() + + image = intermediate_images + + if output_type == "pil": + # 8. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 9. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + + # 11. Apply watermark + if self.watermarker is not None: + self.watermarker.apply_watermark(image, self.unet.config.sample_size) + elif output_type == "pt": + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + else: + # 8. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 9. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, nsfw_detected, watermark_detected) + + return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py new file mode 100644 index 0000000000000000000000000000000000000000..4839a086046257990df860bff0ce1713b1110cdd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py @@ -0,0 +1,1027 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from ...loaders import StableDiffusionLoraLoaderMixin +from ...models import UNet2DConditionModel +from ...schedulers import DDPMScheduler +from ...utils import ( + BACKENDS_MAPPING, + PIL_INTERPOLATION, + is_bs4_available, + is_ftfy_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import IFPipelineOutput +from .safety_checker import IFSafetyChecker +from .watermark import IFWatermarker + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.resize +def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image: + w, h = images.size + + coef = w / h + + w, h = img_size, img_size + + if coef >= 1: + w = int(round(img_size / 8 * coef) * 8) + else: + h = int(round(img_size / 8 / coef) * 8) + + images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None) + + return images + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, DiffusionPipeline + >>> from diffusers.utils import pt_to_pil + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from io import BytesIO + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> response = requests.get(url) + >>> original_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> original_image = original_image.resize((768, 512)) + + >>> pipe = IFImg2ImgPipeline.from_pretrained( + ... "DeepFloyd/IF-I-XL-v1.0", + ... variant="fp16", + ... torch_dtype=torch.float16, + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A fantasy landscape in style minecraft" + >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + + >>> image = pipe( + ... image=original_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... output_type="pt", + ... ).images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> super_res_1_pipe = IFImg2ImgSuperResolutionPipeline.from_pretrained( + ... "DeepFloyd/IF-II-L-v1.0", + ... text_encoder=None, + ... variant="fp16", + ... torch_dtype=torch.float16, + ... ) + >>> super_res_1_pipe.enable_model_cpu_offload() + + >>> image = super_res_1_pipe( + ... image=image, + ... original_image=original_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... ).images + >>> image[0].save("./if_stage_II.png") + ``` +""" + + +class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet2DConditionModel + scheduler: DDPMScheduler + image_noising_scheduler: DDPMScheduler + + feature_extractor: CLIPImageProcessor | None + safety_checker: IFSafetyChecker | None + + watermarker: IFWatermarker | None + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor"] + model_cpu_offload_seq = "text_encoder->unet" + _exclude_from_cpu_offload = ["watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + image_noising_scheduler: DDPMScheduler, + safety_checker: IFSafetyChecker | None, + feature_extractor: CLIPImageProcessor | None, + watermarker: IFWatermarker | None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the IF license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if unet is not None and unet.config.in_channels != 6: + logger.warning( + "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`." + ) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + image_noising_scheduler=image_noising_scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + watermarker=watermarker, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + @torch.no_grad() + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + device: torch.device | None = None, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + clean_caption (bool, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + original_image, + batch_size, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # image + + if isinstance(image, list): + check_image_type = image[0] + else: + check_image_type = image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or list[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(image, list): + image_batch_size = len(image) + elif isinstance(image, torch.Tensor): + image_batch_size = image.shape[0] + elif isinstance(image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(image, np.ndarray): + image_batch_size = image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") + + # original_image + + if isinstance(original_image, list): + check_image_type = original_image[0] + else: + check_image_type = original_image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`original_image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or list[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(original_image, list): + image_batch_size = len(original_image) + elif isinstance(original_image, torch.Tensor): + image_batch_size = original_image.shape[0] + elif isinstance(original_image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(original_image, np.ndarray): + image_batch_size = original_image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError( + f"original_image batch size: {image_batch_size} must be same as prompt batch size {batch_size}" + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.preprocess_image with preprocess_image -> preprocess_original_image + def preprocess_original_image(self, image: PIL.Image.Image) -> torch.Tensor: + if not isinstance(image, list): + image = [image] + + def numpy_to_pt(images): + if images.ndim == 3: + images = images[..., None] + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + if isinstance(image[0], PIL.Image.Image): + new_image = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = resize(image_, self.unet.config.sample_size) + image_ = np.array(image_) + image_ = image_.astype(np.float32) + image_ = image_ / 127.5 - 1 + new_image.append(image_) + + image = new_image + + image = np.stack(image, axis=0) # to np + image = numpy_to_pt(image) # to pt + + elif isinstance(image[0], np.ndarray): + image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) + image = numpy_to_pt(image) + + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + return image + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_superresolution.IFSuperResolutionPipeline.preprocess_image + def preprocess_image(self, image: PIL.Image.Image, num_images_per_prompt, device) -> torch.Tensor: + if not isinstance(image, torch.Tensor) and not isinstance(image, list): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + image = [np.array(i).astype(np.float32) / 127.5 - 1.0 for i in image] + + image = np.stack(image, axis=0) # to np + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image[0], np.ndarray): + image = np.stack(image, axis=0) # to np + if image.ndim == 5: + image = image[0] + + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image, list) and isinstance(image[0], torch.Tensor): + dims = image[0].ndim + + if dims == 3: + image = torch.stack(image, dim=0) + elif dims == 4: + image = torch.concat(image, dim=0) + else: + raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}") + + image = image.to(device=device, dtype=self.unet.dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.prepare_intermediate_images + def prepare_intermediate_images( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None + ): + _, channels, height, width = image.shape + + batch_size = batch_size * num_images_per_prompt + + shape = (batch_size, channels, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + image = self.scheduler.add_noise(image, noise, timestep) + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PIL.Image.Image | np.ndarray | torch.Tensor, + original_image: PIL.Image.Image + | torch.Tensor + | np.ndarray + | list[PIL.Image.Image] + | list[torch.Tensor] + | list[np.ndarray] = None, + strength: float = 0.8, + prompt: str | list[str] = None, + num_inference_steps: int = 50, + timesteps: list[int] = None, + guidance_scale: float = 4.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + noise_level: int = 250, + clean_caption: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + original_image (`torch.Tensor` or `PIL.Image.Image`): + The original image that `image` was varied from. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + noise_level (`int`, *optional*, defaults to 250): + The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)` + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + self.check_inputs( + prompt, + image, + original_image, + batch_size, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + device = self._execution_device + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + dtype = prompt_embeds.dtype + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) + + # 5. prepare original image + original_image = self.preprocess_original_image(original_image) + original_image = original_image.to(device=device, dtype=dtype) + + # 6. Prepare intermediate images + noise_timestep = timesteps[0:1] + noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt) + + intermediate_images = self.prepare_intermediate_images( + original_image, + noise_timestep, + batch_size, + num_images_per_prompt, + dtype, + device, + generator, + ) + + # 7. Prepare upscaled image and noise level + _, _, height, width = original_image.shape + + image = self.preprocess_image(image, num_images_per_prompt, device) + + upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True) + + noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device) + noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype) + upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level) + + if do_classifier_free_guidance: + noise_level = torch.cat([noise_level] * 2) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = torch.cat([intermediate_images, upscaled], dim=1) + + model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + class_labels=noise_level, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + if XLA_AVAILABLE: + xm.mark_step() + + image = intermediate_images + + if output_type == "pil": + # 10. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 11. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 12. Convert to PIL + image = self.numpy_to_pil(image) + + # 13. Apply watermark + if self.watermarker is not None: + self.watermarker.apply_watermark(image, self.unet.config.sample_size) + elif output_type == "pt": + nsfw_detected = None + watermark_detected = None + + else: + # 10. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 11. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, nsfw_detected, watermark_detected) + + return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py new file mode 100644 index 0000000000000000000000000000000000000000..03a9d6f7c5e8584a0d57b3697bfe085a0b1640dd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py @@ -0,0 +1,1032 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from ...loaders import StableDiffusionLoraLoaderMixin +from ...models import UNet2DConditionModel +from ...schedulers import DDPMScheduler +from ...utils import ( + BACKENDS_MAPPING, + PIL_INTERPOLATION, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import IFPipelineOutput +from .safety_checker import IFSafetyChecker +from .watermark import IFWatermarker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.resize +def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image: + w, h = images.size + + coef = w / h + + w, h = img_size, img_size + + if coef >= 1: + w = int(round(img_size / 8 * coef) * 8) + else: + h = int(round(img_size / 8 / coef) * 8) + + images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None) + + return images + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, DiffusionPipeline + >>> from diffusers.utils import pt_to_pil + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from io import BytesIO + + >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/person.png" + >>> response = requests.get(url) + >>> original_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> original_image = original_image + + >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/glasses_mask.png" + >>> response = requests.get(url) + >>> mask_image = Image.open(BytesIO(response.content)) + >>> mask_image = mask_image + + >>> pipe = IFInpaintingPipeline.from_pretrained( + ... "DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "blue sunglasses" + >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + + >>> image = pipe( + ... image=original_image, + ... mask_image=mask_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... output_type="pt", + ... ).images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> super_res_1_pipe = IFInpaintingSuperResolutionPipeline.from_pretrained( + ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 + ... ) + >>> super_res_1_pipe.enable_model_cpu_offload() + + >>> image = super_res_1_pipe( + ... image=image, + ... mask_image=mask_image, + ... original_image=original_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... ).images + >>> image[0].save("./if_stage_II.png") + ``` +""" + + +class IFInpaintingPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet2DConditionModel + scheduler: DDPMScheduler + + feature_extractor: CLIPImageProcessor | None + safety_checker: IFSafetyChecker | None + + watermarker: IFWatermarker | None + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] + model_cpu_offload_seq = "text_encoder->unet" + _exclude_from_cpu_offload = ["watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + safety_checker: IFSafetyChecker | None, + feature_extractor: CLIPImageProcessor | None, + watermarker: IFWatermarker | None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the IF license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + watermarker=watermarker, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + @torch.no_grad() + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + device: torch.device | None = None, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + clean_caption (bool, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + mask_image, + batch_size, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # image + + if isinstance(image, list): + check_image_type = image[0] + else: + check_image_type = image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or list[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(image, list): + image_batch_size = len(image) + elif isinstance(image, torch.Tensor): + image_batch_size = image.shape[0] + elif isinstance(image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(image, np.ndarray): + image_batch_size = image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") + + # mask_image + + if isinstance(mask_image, list): + check_image_type = mask_image[0] + else: + check_image_type = mask_image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`mask_image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or list[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(mask_image, list): + image_batch_size = len(mask_image) + elif isinstance(mask_image, torch.Tensor): + image_batch_size = mask_image.shape[0] + elif isinstance(mask_image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(mask_image, np.ndarray): + image_batch_size = mask_image.shape[0] + else: + assert False + + if image_batch_size != 1 and batch_size != image_batch_size: + raise ValueError( + f"mask_image batch size: {image_batch_size} must be `1` or the same as prompt batch size {batch_size}" + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.preprocess_image + def preprocess_image(self, image: PIL.Image.Image) -> torch.Tensor: + if not isinstance(image, list): + image = [image] + + def numpy_to_pt(images): + if images.ndim == 3: + images = images[..., None] + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + if isinstance(image[0], PIL.Image.Image): + new_image = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = resize(image_, self.unet.config.sample_size) + image_ = np.array(image_) + image_ = image_.astype(np.float32) + image_ = image_ / 127.5 - 1 + new_image.append(image_) + + image = new_image + + image = np.stack(image, axis=0) # to np + image = numpy_to_pt(image) # to pt + + elif isinstance(image[0], np.ndarray): + image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) + image = numpy_to_pt(image) + + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + return image + + def preprocess_mask_image(self, mask_image) -> torch.Tensor: + if not isinstance(mask_image, list): + mask_image = [mask_image] + + if isinstance(mask_image[0], torch.Tensor): + mask_image = torch.cat(mask_image, axis=0) if mask_image[0].ndim == 4 else torch.stack(mask_image, axis=0) + + if mask_image.ndim == 2: + # Batch and add channel dim for single mask + mask_image = mask_image.unsqueeze(0).unsqueeze(0) + elif mask_image.ndim == 3 and mask_image.shape[0] == 1: + # Single mask, the 0'th dimension is considered to be + # the existing batch size of 1 + mask_image = mask_image.unsqueeze(0) + elif mask_image.ndim == 3 and mask_image.shape[0] != 1: + # Batch of mask, the 0'th dimension is considered to be + # the batching dimension + mask_image = mask_image.unsqueeze(1) + + mask_image[mask_image < 0.5] = 0 + mask_image[mask_image >= 0.5] = 1 + + elif isinstance(mask_image[0], PIL.Image.Image): + new_mask_image = [] + + for mask_image_ in mask_image: + mask_image_ = mask_image_.convert("L") + mask_image_ = resize(mask_image_, self.unet.config.sample_size) + mask_image_ = np.array(mask_image_) + mask_image_ = mask_image_[None, None, :] + new_mask_image.append(mask_image_) + + mask_image = new_mask_image + + mask_image = np.concatenate(mask_image, axis=0) + mask_image = mask_image.astype(np.float32) / 255.0 + mask_image[mask_image < 0.5] = 0 + mask_image[mask_image >= 0.5] = 1 + mask_image = torch.from_numpy(mask_image) + + elif isinstance(mask_image[0], np.ndarray): + mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0) + + mask_image[mask_image < 0.5] = 0 + mask_image[mask_image >= 0.5] = 1 + mask_image = torch.from_numpy(mask_image) + + return mask_image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def prepare_intermediate_images( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, mask_image, generator=None + ): + image_batch_size, channels, height, width = image.shape + + batch_size = batch_size * num_images_per_prompt + + shape = (batch_size, channels, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + noised_image = self.scheduler.add_noise(image, noise, timestep) + + image = (1 - mask_image) * image + mask_image * noised_image + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + image: PIL.Image.Image + | torch.Tensor + | np.ndarray + | list[PIL.Image.Image] + | list[torch.Tensor] + | list[np.ndarray] = None, + mask_image: PIL.Image.Image + | torch.Tensor + | np.ndarray + | list[PIL.Image.Image] + | list[torch.Tensor] + | list[np.ndarray] = None, + strength: float = 1.0, + num_inference_steps: int = 50, + timesteps: list[int] = None, + guidance_scale: float = 7.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + clean_caption: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.Tensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + strength (`float`, *optional*, defaults to 1.0): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + self.check_inputs( + prompt, + image, + mask_image, + batch_size, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + dtype = prompt_embeds.dtype + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) + + # 5. Prepare intermediate images + image = self.preprocess_image(image) + image = image.to(device=device, dtype=dtype) + + mask_image = self.preprocess_mask_image(mask_image) + mask_image = mask_image.to(device=device, dtype=dtype) + + if mask_image.shape[0] == 1: + mask_image = mask_image.repeat_interleave(batch_size * num_images_per_prompt, dim=0) + else: + mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0) + + noise_timestep = timesteps[0:1] + noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt) + + intermediate_images = self.prepare_intermediate_images( + image, noise_timestep, batch_size, num_images_per_prompt, dtype, device, mask_image, generator + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = ( + torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images + ) + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + prev_intermediate_images = intermediate_images + + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] + + intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + if XLA_AVAILABLE: + xm.mark_step() + + image = intermediate_images + + if output_type == "pil": + # 8. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 9. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + + # 11. Apply watermark + if self.watermarker is not None: + self.watermarker.apply_watermark(image, self.unet.config.sample_size) + elif output_type == "pt": + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + else: + # 8. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 9. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, nsfw_detected, watermark_detected) + + return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py new file mode 100644 index 0000000000000000000000000000000000000000..841382ad9c6347f906fa4fee0d2fd7e26eb08720 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py @@ -0,0 +1,1140 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from ...loaders import StableDiffusionLoraLoaderMixin +from ...models import UNet2DConditionModel +from ...schedulers import DDPMScheduler +from ...utils import ( + BACKENDS_MAPPING, + PIL_INTERPOLATION, + is_bs4_available, + is_ftfy_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import IFPipelineOutput +from .safety_checker import IFSafetyChecker +from .watermark import IFWatermarker + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.resize +def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image: + w, h = images.size + + coef = w / h + + w, h = img_size, img_size + + if coef >= 1: + w = int(round(img_size / 8 * coef) * 8) + else: + h = int(round(img_size / 8 / coef) * 8) + + images = images.resize((w, h), resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None) + + return images + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, DiffusionPipeline + >>> from diffusers.utils import pt_to_pil + >>> import torch + >>> from PIL import Image + >>> import requests + >>> from io import BytesIO + + >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/person.png" + >>> response = requests.get(url) + >>> original_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> original_image = original_image + + >>> url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/glasses_mask.png" + >>> response = requests.get(url) + >>> mask_image = Image.open(BytesIO(response.content)) + >>> mask_image = mask_image + + >>> pipe = IFInpaintingPipeline.from_pretrained( + ... "DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "blue sunglasses" + + >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + >>> image = pipe( + ... image=original_image, + ... mask_image=mask_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... output_type="pt", + ... ).images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> super_res_1_pipe = IFInpaintingSuperResolutionPipeline.from_pretrained( + ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 + ... ) + >>> super_res_1_pipe.enable_model_cpu_offload() + + >>> image = super_res_1_pipe( + ... image=image, + ... mask_image=mask_image, + ... original_image=original_image, + ... prompt_embeds=prompt_embeds, + ... negative_prompt_embeds=negative_embeds, + ... ).images + >>> image[0].save("./if_stage_II.png") + ``` + """ + + +class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet2DConditionModel + scheduler: DDPMScheduler + image_noising_scheduler: DDPMScheduler + + feature_extractor: CLIPImageProcessor | None + safety_checker: IFSafetyChecker | None + + watermarker: IFWatermarker | None + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + model_cpu_offload_seq = "text_encoder->unet" + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] + _exclude_from_cpu_offload = ["watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + image_noising_scheduler: DDPMScheduler, + safety_checker: IFSafetyChecker | None, + feature_extractor: CLIPImageProcessor | None, + watermarker: IFWatermarker | None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the IF license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if unet is not None and unet.config.in_channels != 6: + logger.warning( + "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`." + ) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + image_noising_scheduler=image_noising_scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + watermarker=watermarker, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + @torch.no_grad() + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + device: torch.device | None = None, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + clean_caption (bool, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + original_image, + mask_image, + batch_size, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # image + + if isinstance(image, list): + check_image_type = image[0] + else: + check_image_type = image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or list[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(image, list): + image_batch_size = len(image) + elif isinstance(image, torch.Tensor): + image_batch_size = image.shape[0] + elif isinstance(image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(image, np.ndarray): + image_batch_size = image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") + + # original_image + + if isinstance(original_image, list): + check_image_type = original_image[0] + else: + check_image_type = original_image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`original_image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or list[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(original_image, list): + image_batch_size = len(original_image) + elif isinstance(original_image, torch.Tensor): + image_batch_size = original_image.shape[0] + elif isinstance(original_image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(original_image, np.ndarray): + image_batch_size = original_image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError( + f"original_image batch size: {image_batch_size} must be same as prompt batch size {batch_size}" + ) + + # mask_image + + if isinstance(mask_image, list): + check_image_type = mask_image[0] + else: + check_image_type = mask_image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`mask_image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or list[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(mask_image, list): + image_batch_size = len(mask_image) + elif isinstance(mask_image, torch.Tensor): + image_batch_size = mask_image.shape[0] + elif isinstance(mask_image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(mask_image, np.ndarray): + image_batch_size = mask_image.shape[0] + else: + assert False + + if image_batch_size != 1 and batch_size != image_batch_size: + raise ValueError( + f"mask_image batch size: {image_batch_size} must be `1` or the same as prompt batch size {batch_size}" + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.preprocess_image with preprocess_image -> preprocess_original_image + def preprocess_original_image(self, image: PIL.Image.Image) -> torch.Tensor: + if not isinstance(image, list): + image = [image] + + def numpy_to_pt(images): + if images.ndim == 3: + images = images[..., None] + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + if isinstance(image[0], PIL.Image.Image): + new_image = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = resize(image_, self.unet.config.sample_size) + image_ = np.array(image_) + image_ = image_.astype(np.float32) + image_ = image_ / 127.5 - 1 + new_image.append(image_) + + image = new_image + + image = np.stack(image, axis=0) # to np + image = numpy_to_pt(image) # to pt + + elif isinstance(image[0], np.ndarray): + image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) + image = numpy_to_pt(image) + + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + return image + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_superresolution.IFSuperResolutionPipeline.preprocess_image + def preprocess_image(self, image: PIL.Image.Image, num_images_per_prompt, device) -> torch.Tensor: + if not isinstance(image, torch.Tensor) and not isinstance(image, list): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + image = [np.array(i).astype(np.float32) / 127.5 - 1.0 for i in image] + + image = np.stack(image, axis=0) # to np + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image[0], np.ndarray): + image = np.stack(image, axis=0) # to np + if image.ndim == 5: + image = image[0] + + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image, list) and isinstance(image[0], torch.Tensor): + dims = image[0].ndim + + if dims == 3: + image = torch.stack(image, dim=0) + elif dims == 4: + image = torch.concat(image, dim=0) + else: + raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}") + + image = image.to(device=device, dtype=self.unet.dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + + return image + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_inpainting.IFInpaintingPipeline.preprocess_mask_image + def preprocess_mask_image(self, mask_image) -> torch.Tensor: + if not isinstance(mask_image, list): + mask_image = [mask_image] + + if isinstance(mask_image[0], torch.Tensor): + mask_image = torch.cat(mask_image, axis=0) if mask_image[0].ndim == 4 else torch.stack(mask_image, axis=0) + + if mask_image.ndim == 2: + # Batch and add channel dim for single mask + mask_image = mask_image.unsqueeze(0).unsqueeze(0) + elif mask_image.ndim == 3 and mask_image.shape[0] == 1: + # Single mask, the 0'th dimension is considered to be + # the existing batch size of 1 + mask_image = mask_image.unsqueeze(0) + elif mask_image.ndim == 3 and mask_image.shape[0] != 1: + # Batch of mask, the 0'th dimension is considered to be + # the batching dimension + mask_image = mask_image.unsqueeze(1) + + mask_image[mask_image < 0.5] = 0 + mask_image[mask_image >= 0.5] = 1 + + elif isinstance(mask_image[0], PIL.Image.Image): + new_mask_image = [] + + for mask_image_ in mask_image: + mask_image_ = mask_image_.convert("L") + mask_image_ = resize(mask_image_, self.unet.config.sample_size) + mask_image_ = np.array(mask_image_) + mask_image_ = mask_image_[None, None, :] + new_mask_image.append(mask_image_) + + mask_image = new_mask_image + + mask_image = np.concatenate(mask_image, axis=0) + mask_image = mask_image.astype(np.float32) / 255.0 + mask_image[mask_image < 0.5] = 0 + mask_image[mask_image >= 0.5] = 1 + mask_image = torch.from_numpy(mask_image) + + elif isinstance(mask_image[0], np.ndarray): + mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0) + + mask_image[mask_image < 0.5] = 0 + mask_image[mask_image >= 0.5] = 1 + mask_image = torch.from_numpy(mask_image) + + return mask_image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_inpainting.IFInpaintingPipeline.prepare_intermediate_images + def prepare_intermediate_images( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, mask_image, generator=None + ): + image_batch_size, channels, height, width = image.shape + + batch_size = batch_size * num_images_per_prompt + + shape = (batch_size, channels, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + noised_image = self.scheduler.add_noise(image, noise, timestep) + + image = (1 - mask_image) * image + mask_image * noised_image + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PIL.Image.Image | np.ndarray | torch.Tensor, + original_image: PIL.Image.Image + | torch.Tensor + | np.ndarray + | list[PIL.Image.Image] + | list[torch.Tensor] + | list[np.ndarray] = None, + mask_image: PIL.Image.Image + | torch.Tensor + | np.ndarray + | list[PIL.Image.Image] + | list[torch.Tensor] + | list[np.ndarray] = None, + strength: float = 0.8, + prompt: str | list[str] = None, + num_inference_steps: int = 100, + timesteps: list[int] = None, + guidance_scale: float = 4.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + noise_level: int = 0, + clean_caption: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + original_image (`torch.Tensor` or `PIL.Image.Image`): + The original image that `image` was varied from. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + noise_level (`int`, *optional*, defaults to 0): + The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)` + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + self.check_inputs( + prompt, + image, + original_image, + mask_image, + batch_size, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + device = self._execution_device + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + dtype = prompt_embeds.dtype + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) + + # 5. prepare original image + original_image = self.preprocess_original_image(original_image) + original_image = original_image.to(device=device, dtype=dtype) + + # 6. prepare mask image + mask_image = self.preprocess_mask_image(mask_image) + mask_image = mask_image.to(device=device, dtype=dtype) + + if mask_image.shape[0] == 1: + mask_image = mask_image.repeat_interleave(batch_size * num_images_per_prompt, dim=0) + else: + mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0) + + # 6. Prepare intermediate images + noise_timestep = timesteps[0:1] + noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt) + + intermediate_images = self.prepare_intermediate_images( + original_image, + noise_timestep, + batch_size, + num_images_per_prompt, + dtype, + device, + mask_image, + generator, + ) + + # 7. Prepare upscaled image and noise level + _, _, height, width = original_image.shape + + image = self.preprocess_image(image, num_images_per_prompt, device) + + upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True) + + noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device) + noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype) + upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level) + + if do_classifier_free_guidance: + noise_level = torch.cat([noise_level] * 2) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = torch.cat([intermediate_images, upscaled], dim=1) + + model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + class_labels=noise_level, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + prev_intermediate_images = intermediate_images + + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] + + intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + if XLA_AVAILABLE: + xm.mark_step() + + image = intermediate_images + + if output_type == "pil": + # 10. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 11. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 12. Convert to PIL + image = self.numpy_to_pil(image) + + # 13. Apply watermark + if self.watermarker is not None: + self.watermarker.apply_watermark(image, self.unet.config.sample_size) + elif output_type == "pt": + nsfw_detected = None + watermark_detected = None + + else: + # 10. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 11. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + self.maybe_free_model_hooks() + + if not return_dict: + return (image, nsfw_detected, watermark_detected) + + return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py new file mode 100644 index 0000000000000000000000000000000000000000..52ebebb6f9b48ddbf1a57e17b09d3829f4e58f0a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py @@ -0,0 +1,883 @@ +import html +import inspect +import re +import urllib.parse as ul +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from ...loaders import StableDiffusionLoraLoaderMixin +from ...models import UNet2DConditionModel +from ...schedulers import DDPMScheduler +from ...utils import ( + BACKENDS_MAPPING, + is_bs4_available, + is_ftfy_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import IFPipelineOutput +from .safety_checker import IFSafetyChecker +from .watermark import IFWatermarker + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import IFPipeline, IFSuperResolutionPipeline, DiffusionPipeline + >>> from diffusers.utils import pt_to_pil + >>> import torch + + >>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"' + >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) + + >>> image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images + + >>> # save intermediate image + >>> pil_image = pt_to_pil(image) + >>> pil_image[0].save("./if_stage_I.png") + + >>> super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained( + ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 + ... ) + >>> super_res_1_pipe.enable_model_cpu_offload() + + >>> image = super_res_1_pipe( + ... image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds + ... ).images + >>> image[0].save("./if_stage_II.png") + ``` +""" + + +class IFSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin): + tokenizer: T5Tokenizer + text_encoder: T5EncoderModel + + unet: UNet2DConditionModel + scheduler: DDPMScheduler + image_noising_scheduler: DDPMScheduler + + feature_extractor: CLIPImageProcessor | None + safety_checker: IFSafetyChecker | None + + watermarker: IFWatermarker | None + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] + model_cpu_offload_seq = "text_encoder->unet" + _exclude_from_cpu_offload = ["watermarker"] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + image_noising_scheduler: DDPMScheduler, + safety_checker: IFSafetyChecker | None, + feature_extractor: CLIPImageProcessor | None, + watermarker: IFWatermarker | None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the IF license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if unet is not None and unet.config.in_channels != 6: + logger.warning( + "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`." + ) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + image_noising_scheduler=image_noising_scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + watermarker=watermarker, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + @torch.no_grad() + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + device: torch.device | None = None, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + clean_caption: bool = False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + clean_caption (bool, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF + max_length = 77 + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.unet is not None: + dtype = self.unet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + batch_size, + noise_level, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps: + raise ValueError( + f"`noise_level`: {noise_level} must be a valid timestep in `self.noising_scheduler`, [0, {self.image_noising_scheduler.config.num_train_timesteps})" + ) + + if isinstance(image, list): + check_image_type = image[0] + else: + check_image_type = image + + if ( + not isinstance(check_image_type, torch.Tensor) + and not isinstance(check_image_type, PIL.Image.Image) + and not isinstance(check_image_type, np.ndarray) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or list[...] but is" + f" {type(check_image_type)}" + ) + + if isinstance(image, list): + image_batch_size = len(image) + elif isinstance(image, torch.Tensor): + image_batch_size = image.shape[0] + elif isinstance(image, PIL.Image.Image): + image_batch_size = 1 + elif isinstance(image, np.ndarray): + image_batch_size = image.shape[0] + else: + assert False + + if batch_size != image_batch_size: + raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_intermediate_images + def prepare_intermediate_images(self, batch_size, num_channels, height, width, dtype, device, generator): + shape = (batch_size, num_channels, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + intermediate_images = intermediate_images * self.scheduler.init_noise_sigma + return intermediate_images + + def preprocess_image(self, image, num_images_per_prompt, device): + if not isinstance(image, torch.Tensor) and not isinstance(image, list): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + image = [np.array(i).astype(np.float32) / 127.5 - 1.0 for i in image] + + image = np.stack(image, axis=0) # to np + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image[0], np.ndarray): + image = np.stack(image, axis=0) # to np + if image.ndim == 5: + image = image[0] + + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) + elif isinstance(image, list) and isinstance(image[0], torch.Tensor): + dims = image[0].ndim + + if dims == 3: + image = torch.stack(image, dim=0) + elif dims == 4: + image = torch.concat(image, dim=0) + else: + raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}") + + image = image.to(device=device, dtype=self.unet.dtype) + + image = image.repeat_interleave(num_images_per_prompt, dim=0) + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int = None, + width: int = None, + image: PIL.Image.Image | np.ndarray | torch.Tensor = None, + num_inference_steps: int = 50, + timesteps: list[int] = None, + guidance_scale: float = 4.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + noise_level: int = 250, + clean_caption: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to None): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to None): + The width in pixels of the generated image. + image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`): + The image to be upscaled. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*, defaults to None): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + noise_level (`int`, *optional*, defaults to 250): + The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)` + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + or watermarked content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + self.check_inputs( + prompt, + image, + batch_size, + noise_level, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + + height = height or self.unet.config.sample_size + width = width or self.unet.config.sample_size + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(0) + + # 5. Prepare intermediate images + num_channels = self.unet.config.in_channels // 2 + intermediate_images = self.prepare_intermediate_images( + batch_size * num_images_per_prompt, + num_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare upscaled image and noise level + image = self.preprocess_image(image, num_images_per_prompt, device) + upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True) + + noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device) + noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype) + upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level) + + if do_classifier_free_guidance: + noise_level = torch.cat([noise_level] * 2) + + # HACK: see comment in `enable_model_cpu_offload` + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_input = torch.cat([intermediate_images, upscaled], dim=1) + + model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input + model_input = self.scheduler.scale_model_input(model_input, t) + + # predict the noise residual + noise_pred = self.unet( + model_input, + t, + encoder_hidden_states=prompt_embeds, + class_labels=noise_level, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + intermediate_images = self.scheduler.step( + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, intermediate_images) + + if XLA_AVAILABLE: + xm.mark_step() + + image = intermediate_images + + if output_type == "pil": + # 9. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 10. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 11. Convert to PIL + image = self.numpy_to_pil(image) + + # 12. Apply watermark + if self.watermarker is not None: + self.watermarker.apply_watermark(image, self.unet.config.sample_size) + elif output_type == "pt": + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + else: + # 9. Post-processing + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # 10. Run safety checker + image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, nsfw_detected, watermark_detected) + + return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..7fe1cd013835704d4d068961aa1fa43000c45d10 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/pipeline_output.py @@ -0,0 +1,28 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class IFPipelineOutput(BaseOutput): + r""" + Output class for Stable Diffusion pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`): + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_detected (`list[bool]`): + list of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content or a watermark. `None` if safety checking could not be performed. + watermark_detected (`list[bool]`): + list of flags denoting whether the corresponding generated image likely has a watermark. `None` if safety + checking could not be performed. + """ + + images: list[PIL.Image.Image] | np.ndarray + nsfw_detected: list[bool] | None + watermark_detected: list[bool] | None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/safety_checker.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/safety_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..8ffeed580bbea1514b11bf7a168a952328d8f424 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/safety_checker.py @@ -0,0 +1,59 @@ +import numpy as np +import torch +import torch.nn as nn +from transformers import CLIPConfig, CLIPVisionModelWithProjection, PreTrainedModel + +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class IFSafetyChecker(PreTrainedModel): + config_class = CLIPConfig + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPConfig): + super().__init__(config) + + self.vision_model = CLIPVisionModelWithProjection(config.vision_config) + + self.p_head = nn.Linear(config.vision_config.projection_dim, 1) + self.w_head = nn.Linear(config.vision_config.projection_dim, 1) + + @torch.no_grad() + def forward(self, clip_input, images, p_threshold=0.5, w_threshold=0.5): + image_embeds = self.vision_model(clip_input)[0] + + nsfw_detected = self.p_head(image_embeds) + nsfw_detected = nsfw_detected.flatten() + nsfw_detected = nsfw_detected > p_threshold + nsfw_detected = nsfw_detected.tolist() + + if any(nsfw_detected): + logger.warning( + "Potential NSFW content was detected in one or more images. A black image will be returned instead." + " Try again with a different prompt and/or seed." + ) + + for idx, nsfw_detected_ in enumerate(nsfw_detected): + if nsfw_detected_: + images[idx] = np.zeros(images[idx].shape) + + watermark_detected = self.w_head(image_embeds) + watermark_detected = watermark_detected.flatten() + watermark_detected = watermark_detected > w_threshold + watermark_detected = watermark_detected.tolist() + + if any(watermark_detected): + logger.warning( + "Potential watermarked content was detected in one or more images. A black image will be returned instead." + " Try again with a different prompt and/or seed." + ) + + for idx, watermark_detected_ in enumerate(watermark_detected): + if watermark_detected_: + images[idx] = np.zeros(images[idx].shape) + + return images, nsfw_detected, watermark_detected diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/timesteps.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/timesteps.py new file mode 100644 index 0000000000000000000000000000000000000000..d44285c017bbb2ccffa4ae86dd77792a048625d9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/timesteps.py @@ -0,0 +1,579 @@ +fast27_timesteps = [ + 999, + 800, + 799, + 600, + 599, + 500, + 400, + 399, + 377, + 355, + 333, + 311, + 288, + 266, + 244, + 222, + 200, + 199, + 177, + 155, + 133, + 111, + 88, + 66, + 44, + 22, + 0, +] + +smart27_timesteps = [ + 999, + 976, + 952, + 928, + 905, + 882, + 858, + 857, + 810, + 762, + 715, + 714, + 572, + 429, + 428, + 286, + 285, + 238, + 190, + 143, + 142, + 118, + 95, + 71, + 47, + 24, + 0, +] + +smart50_timesteps = [ + 999, + 988, + 977, + 966, + 955, + 944, + 933, + 922, + 911, + 900, + 899, + 879, + 859, + 840, + 820, + 800, + 799, + 766, + 733, + 700, + 699, + 650, + 600, + 599, + 500, + 499, + 400, + 399, + 350, + 300, + 299, + 266, + 233, + 200, + 199, + 179, + 159, + 140, + 120, + 100, + 99, + 88, + 77, + 66, + 55, + 44, + 33, + 22, + 11, + 0, +] + +smart100_timesteps = [ + 999, + 995, + 992, + 989, + 985, + 981, + 978, + 975, + 971, + 967, + 964, + 961, + 957, + 956, + 951, + 947, + 942, + 937, + 933, + 928, + 923, + 919, + 914, + 913, + 908, + 903, + 897, + 892, + 887, + 881, + 876, + 871, + 870, + 864, + 858, + 852, + 846, + 840, + 834, + 828, + 827, + 820, + 813, + 806, + 799, + 792, + 785, + 784, + 777, + 770, + 763, + 756, + 749, + 742, + 741, + 733, + 724, + 716, + 707, + 699, + 698, + 688, + 677, + 666, + 656, + 655, + 645, + 634, + 623, + 613, + 612, + 598, + 584, + 570, + 569, + 555, + 541, + 527, + 526, + 505, + 484, + 483, + 462, + 440, + 439, + 396, + 395, + 352, + 351, + 308, + 307, + 264, + 263, + 220, + 219, + 176, + 132, + 88, + 44, + 0, +] + +smart185_timesteps = [ + 999, + 997, + 995, + 992, + 990, + 988, + 986, + 984, + 981, + 979, + 977, + 975, + 972, + 970, + 968, + 966, + 964, + 961, + 959, + 957, + 956, + 954, + 951, + 949, + 946, + 944, + 941, + 939, + 936, + 934, + 931, + 929, + 926, + 924, + 921, + 919, + 916, + 914, + 913, + 910, + 907, + 905, + 902, + 899, + 896, + 893, + 891, + 888, + 885, + 882, + 879, + 877, + 874, + 871, + 870, + 867, + 864, + 861, + 858, + 855, + 852, + 849, + 846, + 843, + 840, + 837, + 834, + 831, + 828, + 827, + 824, + 821, + 817, + 814, + 811, + 808, + 804, + 801, + 798, + 795, + 791, + 788, + 785, + 784, + 780, + 777, + 774, + 770, + 766, + 763, + 760, + 756, + 752, + 749, + 746, + 742, + 741, + 737, + 733, + 730, + 726, + 722, + 718, + 714, + 710, + 707, + 703, + 699, + 698, + 694, + 690, + 685, + 681, + 677, + 673, + 669, + 664, + 660, + 656, + 655, + 650, + 646, + 641, + 636, + 632, + 627, + 622, + 618, + 613, + 612, + 607, + 602, + 596, + 591, + 586, + 580, + 575, + 570, + 569, + 563, + 557, + 551, + 545, + 539, + 533, + 527, + 526, + 519, + 512, + 505, + 498, + 491, + 484, + 483, + 474, + 466, + 457, + 449, + 440, + 439, + 428, + 418, + 407, + 396, + 395, + 381, + 366, + 352, + 351, + 330, + 308, + 307, + 286, + 264, + 263, + 242, + 220, + 219, + 176, + 175, + 132, + 131, + 88, + 44, + 0, +] + +super27_timesteps = [ + 999, + 991, + 982, + 974, + 966, + 958, + 950, + 941, + 933, + 925, + 916, + 908, + 900, + 899, + 874, + 850, + 825, + 800, + 799, + 700, + 600, + 500, + 400, + 300, + 200, + 100, + 0, +] + +super40_timesteps = [ + 999, + 992, + 985, + 978, + 971, + 964, + 957, + 949, + 942, + 935, + 928, + 921, + 914, + 907, + 900, + 899, + 879, + 859, + 840, + 820, + 800, + 799, + 766, + 733, + 700, + 699, + 650, + 600, + 599, + 500, + 499, + 400, + 399, + 300, + 299, + 200, + 199, + 100, + 99, + 0, +] + +super100_timesteps = [ + 999, + 996, + 992, + 989, + 985, + 982, + 979, + 975, + 972, + 968, + 965, + 961, + 958, + 955, + 951, + 948, + 944, + 941, + 938, + 934, + 931, + 927, + 924, + 920, + 917, + 914, + 910, + 907, + 903, + 900, + 899, + 891, + 884, + 876, + 869, + 861, + 853, + 846, + 838, + 830, + 823, + 815, + 808, + 800, + 799, + 788, + 777, + 766, + 755, + 744, + 733, + 722, + 711, + 700, + 699, + 688, + 677, + 666, + 655, + 644, + 633, + 622, + 611, + 600, + 599, + 585, + 571, + 557, + 542, + 528, + 514, + 500, + 499, + 485, + 471, + 457, + 442, + 428, + 414, + 400, + 399, + 379, + 359, + 340, + 320, + 300, + 299, + 279, + 259, + 240, + 220, + 200, + 199, + 166, + 133, + 100, + 99, + 66, + 33, + 0, +] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/watermark.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/watermark.py new file mode 100644 index 0000000000000000000000000000000000000000..d5fe99f681f7351630af885b9c01db4dce1971fe --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deepfloyd_if/watermark.py @@ -0,0 +1,44 @@ +import PIL.Image +import torch +from PIL import Image + +from ...configuration_utils import ConfigMixin +from ...models.modeling_utils import ModelMixin +from ...utils import PIL_INTERPOLATION + + +class IFWatermarker(ModelMixin, ConfigMixin): + def __init__(self): + super().__init__() + + self.register_buffer("watermark_image", torch.zeros((62, 62, 4))) + self.watermark_image_as_pil = None + + def apply_watermark(self, images: list[PIL.Image.Image], sample_size=None): + # Copied from https://github.com/deep-floyd/IF/blob/b77482e36ca2031cb94dbca1001fc1e6400bf4ab/deepfloyd_if/modules/base.py#L287 + + h = images[0].height + w = images[0].width + + sample_size = sample_size or h + + coef = min(h / sample_size, w / sample_size) + img_h, img_w = (int(h / coef), int(w / coef)) if coef < 1 else (h, w) + + S1, S2 = 1024**2, img_w * img_h + K = (S2 / S1) ** 0.5 + wm_size, wm_x, wm_y = int(K * 62), img_w - int(14 * K), img_h - int(14 * K) + + if self.watermark_image_as_pil is None: + watermark_image = self.watermark_image.to(torch.uint8).cpu().numpy() + watermark_image = Image.fromarray(watermark_image, mode="RGBA") + self.watermark_image_as_pil = watermark_image + + wm_img = self.watermark_image_as_pil.resize( + (wm_size, wm_size), PIL_INTERPOLATION["bicubic"], reducing_gap=None + ) + + for pil_img in images: + pil_img.paste(wm_img, box=(wm_x - wm_size, wm_y - wm_size, wm_x, wm_y), mask=wm_img.split()[-1]) + + return images diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9936323170adbceac2c5c25e3881ea731d8602e1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/__init__.py @@ -0,0 +1,153 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_librosa_available, + is_note_seq_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_pt_objects + + _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) +else: + _import_structure["latent_diffusion_uncond"] = ["LDMPipeline"] + _import_structure["pndm"] = ["PNDMPipeline"] + _import_structure["repaint"] = ["RePaintPipeline"] + _import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"] + _import_structure["stochastic_karras_ve"] = ["KarrasVePipeline"] + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["alt_diffusion"] = [ + "AltDiffusionImg2ImgPipeline", + "AltDiffusionPipeline", + "AltDiffusionPipelineOutput", + ] + _import_structure["versatile_diffusion"] = [ + "VersatileDiffusionDualGuidedPipeline", + "VersatileDiffusionImageVariationPipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionTextToImagePipeline", + ] + _import_structure["vq_diffusion"] = ["VQDiffusionPipeline"] + _import_structure["stable_diffusion_variants"] = [ + "CycleDiffusionPipeline", + "StableDiffusionInpaintPipelineLegacy", + "StableDiffusionPix2PixZeroPipeline", + "StableDiffusionParadigmsPipeline", + "StableDiffusionModelEditingPipeline", + ] + +try: + if not (is_torch_available() and is_librosa_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_librosa_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects)) + +else: + _import_structure["audio_diffusion"] = ["AudioDiffusionPipeline", "Mel"] + +try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects)) + +else: + _import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_pt_objects import * + + else: + from .latent_diffusion_uncond import LDMPipeline + from .pndm import PNDMPipeline + from .repaint import RePaintPipeline + from .score_sde_ve import ScoreSdeVePipeline + from .stochastic_karras_ve import KarrasVePipeline + + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + + else: + from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline, AltDiffusionPipelineOutput + from .audio_diffusion import AudioDiffusionPipeline, Mel + from .spectrogram_diffusion import SpectrogramDiffusionPipeline + from .stable_diffusion_variants import ( + CycleDiffusionPipeline, + StableDiffusionInpaintPipelineLegacy, + StableDiffusionModelEditingPipeline, + StableDiffusionParadigmsPipeline, + StableDiffusionPix2PixZeroPipeline, + ) + from .stochastic_karras_ve import KarrasVePipeline + from .versatile_diffusion import ( + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, + ) + from .vq_diffusion import VQDiffusionPipeline + + try: + if not (is_torch_available() and is_librosa_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_librosa_objects import * + else: + from .audio_diffusion import AudioDiffusionPipeline, Mel + + try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 + else: + from .spectrogram_diffusion import ( + MidiProcessor, + SpectrogramDiffusionPipeline, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f2020f827c1d494b87c8d781b62c61a51587eb7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..71fa15b3feff08dc4008d1fa02ba61ad1300efed --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__init__.py @@ -0,0 +1,53 @@ +from typing import TYPE_CHECKING + +from ....utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ....utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modeling_roberta_series"] = ["RobertaSeriesModelWithTransformation"] + _import_structure["pipeline_alt_diffusion"] = ["AltDiffusionPipeline"] + _import_structure["pipeline_alt_diffusion_img2img"] = ["AltDiffusionImg2ImgPipeline"] + + _import_structure["pipeline_output"] = ["AltDiffusionPipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ....utils.dummy_torch_and_transformers_objects import * + + else: + from .modeling_roberta_series import RobertaSeriesModelWithTransformation + from .pipeline_alt_diffusion import AltDiffusionPipeline + from .pipeline_alt_diffusion_img2img import AltDiffusionImg2ImgPipeline + from .pipeline_output import AltDiffusionPipelineOutput + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88a2adf257b37570be3a8b9975eaf6cfb8aeeca2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__pycache__/modeling_roberta_series.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__pycache__/modeling_roberta_series.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..954c2481d3705e237b8112dd73c02616f760c233 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__pycache__/modeling_roberta_series.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__pycache__/pipeline_alt_diffusion.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__pycache__/pipeline_alt_diffusion.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c8f990446f26fca0d13f8e34cb0086386787b9b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__pycache__/pipeline_alt_diffusion.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__pycache__/pipeline_alt_diffusion_img2img.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__pycache__/pipeline_alt_diffusion_img2img.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..443e413723f2db3ff3bc5ec0ff918d1cb226ac6d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__pycache__/pipeline_alt_diffusion_img2img.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__pycache__/pipeline_output.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__pycache__/pipeline_output.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c4deab7a9491aa79cd80da1fe6870634e18e408 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/__pycache__/pipeline_output.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py new file mode 100644 index 0000000000000000000000000000000000000000..ed72e505b9c34ce373ce8a13beaea528ede2a09e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py @@ -0,0 +1,123 @@ +from dataclasses import dataclass + +import torch +from torch import nn +from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel +from transformers.utils import ModelOutput + + +@dataclass +class TransformationModelOutput(ModelOutput): + """ + Base class for text model's outputs that also contains a pooling of the last hidden states. + + Args: + text_embeds (`torch.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + tuple of `torch.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one + for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + tuple of `torch.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + projection_state: torch.Tensor | None = None + last_hidden_state: torch.Tensor = None + hidden_states: tuple[torch.Tensor] | None = None + attentions: tuple[torch.Tensor] | None = None + + +class RobertaSeriesConfig(XLMRobertaConfig): + def __init__( + self, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + project_dim=512, + pooler_fn="cls", + learn_encoder=False, + use_attention_mask=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + self.project_dim = project_dim + self.pooler_fn = pooler_fn + self.learn_encoder = learn_encoder + self.use_attention_mask = use_attention_mask + + +class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler", r"logit_scale"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + base_model_prefix = "roberta" + config_class = RobertaSeriesConfig + + def __init__(self, config): + super().__init__(config) + self.roberta = XLMRobertaModel(config) + self.transformation = nn.Linear(config.hidden_size, config.project_dim) + self.has_pre_transformation = getattr(config, "has_pre_transformation", False) + if self.has_pre_transformation: + self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim) + self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_init() + + def forward( + self, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + head_mask: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + output_attentions: bool | None = None, + return_dict: bool | None = None, + output_hidden_states: bool | None = None, + ): + r""" """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.base_model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=True if self.has_pre_transformation else output_hidden_states, + return_dict=return_dict, + ) + + if self.has_pre_transformation: + sequence_output2 = outputs["hidden_states"][-2] + sequence_output2 = self.pre_LN(sequence_output2) + projection_state2 = self.transformation_pre(sequence_output2) + + return TransformationModelOutput( + projection_state=projection_state2, + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + else: + projection_state = self.transformation(outputs.last_hidden_state) + return TransformationModelOutput( + projection_state=projection_state, + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..9fab42916e9e61176544ac868e46ac2d5f72c1fc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py @@ -0,0 +1,990 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMRobertaTokenizer + +from ....configuration_utils import FrozenDict +from ....image_processor import PipelineImageInput, VaeImageProcessor +from ....loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ....models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ....models.lora import adjust_lora_scale_text_encoder +from ....schedulers import KarrasDiffusionSchedulers +from ....utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .modeling_roberta_series import RobertaSeriesModelWithTransformation +from .pipeline_output import AltDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AltDiffusionPipeline + + >>> pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> # "dark elf princess, highly detailed, d & d, fantasy, highly detailed, digital painting, trending on artstation, concept art, sharp focus, illustration, art by artgerm and greg rutkowski and fuji choko and viktoria gavrilenko and hoang lap" + >>> prompt = "黑暗精灵公主,非常详细,幻想,非常详细,数字绘画,概念艺术,敏锐的焦点,插图" + >>> image = pipe(prompt).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AltDiffusionPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using Alt Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.RobertaSeriesModelWithTransformation`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.XLMRobertaTokenizer`]): + A `XLMRobertaTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: RobertaSeriesModelWithTransformation, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + guidance_rescale: float = 0.0, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + # to deal with lora scaling and other possible forward hooks + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 6.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..b6cd51c6d2030369fb4a0e23dbad58839cb3e4e4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -0,0 +1,1045 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMRobertaTokenizer + +from ....configuration_utils import FrozenDict +from ....image_processor import PipelineImageInput, VaeImageProcessor +from ....loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ....models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ....models.lora import adjust_lora_scale_text_encoder +from ....schedulers import KarrasDiffusionSchedulers +from ....utils import ( + PIL_INTERPOLATION, + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .modeling_roberta_series import RobertaSeriesModelWithTransformation +from .pipeline_output import AltDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import requests + >>> import torch + >>> from PIL import Image + >>> from io import BytesIO + + >>> from diffusers import AltDiffusionImg2ImgPipeline + + >>> device = "cuda" + >>> model_id_or_path = "BAAI/AltDiffusion-m9" + >>> pipe = AltDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) + >>> pipe = pipe.to(device) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + + >>> response = requests.get(url) + >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> init_image = init_image.resize((768, 512)) + + >>> # "A fantasy landscape, trending on artstation" + >>> prompt = "幻想风景, artstation" + + >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images + >>> images[0].save("幻想风景.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AltDiffusionImg2ImgPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-guided image-to-image generation using Alt Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.RobertaSeriesModelWithTransformation`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.XLMRobertaTokenizer`]): + A `XLMRobertaTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: RobertaSeriesModelWithTransformation, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + strength: float = 0.8, + num_inference_steps: int | None = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float | None = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float | None = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + Examples: + + Returns: + [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. set timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 7.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..259a8675696551b4010396aadec14a66a00a372d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/alt_diffusion/pipeline_output.py @@ -0,0 +1,27 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ....utils import ( + BaseOutput, +) + + +@dataclass +# Copied from diffusers.pipelines.stable_diffusion.pipeline_output.StableDiffusionPipelineOutput with Stable->Alt +class AltDiffusionPipelineOutput(BaseOutput): + """ + Output class for Alt Diffusion pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + nsfw_content_detected (`list[bool]`) + list indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or + `None` if safety checking could not be performed. + """ + + images: list[PIL.Image.Image] | np.ndarray + nsfw_content_detected: list[bool] | None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/audio_diffusion/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/audio_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3127951863a7db3f9dd8e42ac5ab64fa9ac3ec0c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/audio_diffusion/__init__.py @@ -0,0 +1,23 @@ +from typing import TYPE_CHECKING + +from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = { + "mel": ["Mel"], + "pipeline_audio_diffusion": ["AudioDiffusionPipeline"], +} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .mel import Mel + from .pipeline_audio_diffusion import AudioDiffusionPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/audio_diffusion/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/audio_diffusion/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b93f2979a5e6ca0193fc35a65cb2a6548e05ff4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/audio_diffusion/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/audio_diffusion/__pycache__/mel.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/audio_diffusion/__pycache__/mel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d40b0317be24d16a6446f5e842ff63a5a0c1bea6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/audio_diffusion/__pycache__/mel.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/audio_diffusion/__pycache__/pipeline_audio_diffusion.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/audio_diffusion/__pycache__/pipeline_audio_diffusion.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5569d37ebef2ae93fe24d7ad5a48a4ddd28ee87b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/audio_diffusion/__pycache__/pipeline_audio_diffusion.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/audio_diffusion/mel.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/audio_diffusion/mel.py new file mode 100644 index 0000000000000000000000000000000000000000..0902f993a0600aa08a39a638f4a75bea5d301415 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/audio_diffusion/mel.py @@ -0,0 +1,179 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np # noqa: E402 + +from ....configuration_utils import ConfigMixin, register_to_config +from ....schedulers.scheduling_utils import SchedulerMixin + + +try: + import librosa # noqa: E402 + + _librosa_can_be_imported = True + _import_error = "" +except Exception as e: + _librosa_can_be_imported = False + _import_error = ( + f"Cannot import librosa because {e}. Make sure to correctly install librosa to be able to install it." + ) + + +from PIL import Image # noqa: E402 + + +class Mel(ConfigMixin, SchedulerMixin): + """ + Parameters: + x_res (`int`): + x resolution of spectrogram (time). + y_res (`int`): + y resolution of spectrogram (frequency bins). + sample_rate (`int`): + Sample rate of audio. + n_fft (`int`): + Number of Fast Fourier Transforms. + hop_length (`int`): + Hop length (a higher number is recommended if `y_res` < 256). + top_db (`int`): + Loudest decibel value. + n_iter (`int`): + Number of iterations for Griffin-Lim Mel inversion. + """ + + config_name = "mel_config.json" + + @register_to_config + def __init__( + self, + x_res: int = 256, + y_res: int = 256, + sample_rate: int = 22050, + n_fft: int = 2048, + hop_length: int = 512, + top_db: int = 80, + n_iter: int = 32, + ): + self.hop_length = hop_length + self.sr = sample_rate + self.n_fft = n_fft + self.top_db = top_db + self.n_iter = n_iter + self.set_resolution(x_res, y_res) + self.audio = None + + if not _librosa_can_be_imported: + raise ValueError(_import_error) + + def set_resolution(self, x_res: int, y_res: int): + """Set resolution. + + Args: + x_res (`int`): + x resolution of spectrogram (time). + y_res (`int`): + y resolution of spectrogram (frequency bins). + """ + self.x_res = x_res + self.y_res = y_res + self.n_mels = self.y_res + self.slice_size = self.x_res * self.hop_length - 1 + + def load_audio(self, audio_file: str = None, raw_audio: np.ndarray = None): + """Load audio. + + Args: + audio_file (`str`): + An audio file that must be on disk due to [Librosa](https://librosa.org/) limitation. + raw_audio (`np.ndarray`): + The raw audio file as a NumPy array. + """ + if audio_file is not None: + self.audio, _ = librosa.load(audio_file, mono=True, sr=self.sr) + else: + self.audio = raw_audio + + # Pad with silence if necessary. + if len(self.audio) < self.x_res * self.hop_length: + self.audio = np.concatenate([self.audio, np.zeros((self.x_res * self.hop_length - len(self.audio),))]) + + def get_number_of_slices(self) -> int: + """Get number of slices in audio. + + Returns: + `int`: + Number of spectograms audio can be sliced into. + """ + return len(self.audio) // self.slice_size + + def get_audio_slice(self, slice: int = 0) -> np.ndarray: + """Get slice of audio. + + Args: + slice (`int`): + Slice number of audio (out of `get_number_of_slices()`). + + Returns: + `np.ndarray`: + The audio slice as a NumPy array. + """ + return self.audio[self.slice_size * slice : self.slice_size * (slice + 1)] + + def get_sample_rate(self) -> int: + """Get sample rate. + + Returns: + `int`: + Sample rate of audio. + """ + return self.sr + + def audio_slice_to_image(self, slice: int) -> Image.Image: + """Convert slice of audio to spectrogram. + + Args: + slice (`int`): + Slice number of audio to convert (out of `get_number_of_slices()`). + + Returns: + `PIL Image`: + A grayscale image of `x_res x y_res`. + """ + S = librosa.feature.melspectrogram( + y=self.get_audio_slice(slice), sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_mels=self.n_mels + ) + log_S = librosa.power_to_db(S, ref=np.max, top_db=self.top_db) + bytedata = (((log_S + self.top_db) * 255 / self.top_db).clip(0, 255) + 0.5).astype(np.uint8) + image = Image.fromarray(bytedata) + return image + + def image_to_audio(self, image: Image.Image) -> np.ndarray: + """Converts spectrogram to audio. + + Args: + image (`PIL Image`): + An grayscale image of `x_res x y_res`. + + Returns: + audio (`np.ndarray`): + The audio as a NumPy array. + """ + bytedata = np.frombuffer(image.tobytes(), dtype="uint8").reshape((image.height, image.width)) + log_S = bytedata.astype("float") * self.top_db / 255 - self.top_db + S = librosa.db_to_power(log_S) + audio = librosa.feature.inverse.mel_to_audio( + S, sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_iter=self.n_iter + ) + return audio diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..f63fc8aacbc8b29f6cbebb87154f918260933ea8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py @@ -0,0 +1,325 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from math import acos, sin + +import numpy as np +import torch +from PIL import Image + +from ....models import AutoencoderKL, UNet2DConditionModel +from ....schedulers import DDIMScheduler, DDPMScheduler +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput +from .mel import Mel + + +class AudioDiffusionPipeline(DiffusionPipeline): + """ + Pipeline for audio diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + vqae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + mel ([`Mel`]): + Transform audio into a spectrogram. + scheduler ([`DDIMScheduler`] or [`DDPMScheduler`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`] or [`DDPMScheduler`]. + """ + + _optional_components = ["vqvae"] + + def __init__( + self, + vqvae: AutoencoderKL, + unet: UNet2DConditionModel, + mel: Mel, + scheduler: DDIMScheduler | DDPMScheduler, + ): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler, mel=mel, vqvae=vqvae) + + def get_default_steps(self) -> int: + """Returns default number of steps recommended for inference. + + Returns: + `int`: + The number of steps. + """ + return 50 if isinstance(self.scheduler, DDIMScheduler) else 1000 + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + audio_file: str = None, + raw_audio: np.ndarray = None, + slice: int = 0, + start_step: int = 0, + steps: int = None, + generator: torch.Generator = None, + mask_start_secs: float = 0, + mask_end_secs: float = 0, + step_generator: torch.Generator = None, + eta: float = 0, + noise: torch.Tensor = None, + encoding: torch.Tensor = None, + return_dict=True, + ) -> AudioPipelineOutput | ImagePipelineOutput | tuple[list[Image.Image], tuple[int, list[np.ndarray]]]: + """ + The call function to the pipeline for generation. + + Args: + batch_size (`int`): + Number of samples to generate. + audio_file (`str`): + An audio file that must be on disk due to [Librosa](https://librosa.org/) limitation. + raw_audio (`np.ndarray`): + The raw audio file as a NumPy array. + slice (`int`): + Slice number of audio to convert. + start_step (int): + Step to start diffusion from. + steps (`int`): + Number of denoising steps (defaults to `50` for DDIM and `1000` for DDPM). + generator (`torch.Generator`): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + mask_start_secs (`float`): + Number of seconds of audio to mask (not generate) at start. + mask_end_secs (`float`): + Number of seconds of audio to mask (not generate) at end. + step_generator (`torch.Generator`): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) used to denoise. + None + eta (`float`): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + noise (`torch.Tensor`): + A noise tensor of shape `(batch_size, 1, height, width)` or `None`. + encoding (`torch.Tensor`): + A tensor for [`UNet2DConditionModel`] of shape `(batch_size, seq_length, cross_attention_dim)`. + return_dict (`bool`): + Whether or not to return a [`AudioPipelineOutput`], [`ImagePipelineOutput`] or a plain tuple. + + Examples: + + For audio diffusion: + + ```py + import torch + from IPython.display import Audio + from diffusers import DiffusionPipeline + + device = "cuda" if torch.cuda.is_available() else "cpu" + pipe = DiffusionPipeline.from_pretrained("teticio/audio-diffusion-256").to(device) + + output = pipe() + display(output.images[0]) + display(Audio(output.audios[0], rate=mel.get_sample_rate())) + ``` + + For latent audio diffusion: + + ```py + import torch + from IPython.display import Audio + from diffusers import DiffusionPipeline + + device = "cuda" if torch.cuda.is_available() else "cpu" + pipe = DiffusionPipeline.from_pretrained("teticio/latent-audio-diffusion-256").to(device) + + output = pipe() + display(output.images[0]) + display(Audio(output.audios[0], rate=pipe.mel.get_sample_rate())) + ``` + + For other tasks like variation, inpainting, outpainting, etc: + + ```py + output = pipe( + raw_audio=output.audios[0, 0], + start_step=int(pipe.get_default_steps() / 2), + mask_start_secs=1, + mask_end_secs=1, + ) + display(output.images[0]) + display(Audio(output.audios[0], rate=pipe.mel.get_sample_rate())) + ``` + + Returns: + `list[PIL Image]`: + A list of Mel spectrograms (`float`, `list[np.ndarray]`) with the sample rate and raw audio. + """ + + steps = steps or self.get_default_steps() + self.scheduler.set_timesteps(steps) + step_generator = step_generator or generator + # For backwards compatibility + if isinstance(self.unet.config.sample_size, int): + self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size) + if noise is None: + noise = randn_tensor( + ( + batch_size, + self.unet.config.in_channels, + self.unet.config.sample_size[0], + self.unet.config.sample_size[1], + ), + generator=generator, + device=self.device, + ) + images = noise + mask = None + + if audio_file is not None or raw_audio is not None: + self.mel.load_audio(audio_file, raw_audio) + input_image = self.mel.audio_slice_to_image(slice) + input_image = np.frombuffer(input_image.tobytes(), dtype="uint8").reshape( + (input_image.height, input_image.width) + ) + input_image = (input_image / 255) * 2 - 1 + input_images = torch.tensor(input_image[np.newaxis, :, :], dtype=torch.float).to(self.device) + + if self.vqvae is not None: + input_images = self.vqvae.encode(torch.unsqueeze(input_images, 0)).latent_dist.sample( + generator=generator + )[0] + input_images = self.vqvae.config.scaling_factor * input_images + + if start_step > 0: + images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1]) + + pixels_per_second = ( + self.unet.config.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length + ) + mask_start = int(mask_start_secs * pixels_per_second) + mask_end = int(mask_end_secs * pixels_per_second) + mask = self.scheduler.add_noise(input_images, noise, torch.tensor(self.scheduler.timesteps[start_step:])) + + for step, t in enumerate(self.progress_bar(self.scheduler.timesteps[start_step:])): + if isinstance(self.unet, UNet2DConditionModel): + model_output = self.unet(images, t, encoding)["sample"] + else: + model_output = self.unet(images, t)["sample"] + + if isinstance(self.scheduler, DDIMScheduler): + images = self.scheduler.step( + model_output=model_output, + timestep=t, + sample=images, + eta=eta, + generator=step_generator, + )["prev_sample"] + else: + images = self.scheduler.step( + model_output=model_output, + timestep=t, + sample=images, + generator=step_generator, + )["prev_sample"] + + if mask is not None: + if mask_start > 0: + images[:, :, :, :mask_start] = mask[:, step, :, :mask_start] + if mask_end > 0: + images[:, :, :, -mask_end:] = mask[:, step, :, -mask_end:] + + if self.vqvae is not None: + # 0.18215 was scaling factor used in training to ensure unit variance + images = 1 / self.vqvae.config.scaling_factor * images + images = self.vqvae.decode(images)["sample"] + + images = (images / 2 + 0.5).clamp(0, 1) + images = images.cpu().permute(0, 2, 3, 1).numpy() + images = (images * 255).round().astype("uint8") + images = list( + (Image.fromarray(_[:, :, 0]) for _ in images) + if images.shape[3] == 1 + else (Image.fromarray(_, mode="RGB").convert("L") for _ in images) + ) + + audios = [self.mel.image_to_audio(_) for _ in images] + if not return_dict: + return images, (self.mel.get_sample_rate(), audios) + + return BaseOutput(**AudioPipelineOutput(np.array(audios)[:, np.newaxis, :]), **ImagePipelineOutput(images)) + + @torch.no_grad() + def encode(self, images: list[Image.Image], steps: int = 50) -> np.ndarray: + """ + Reverse the denoising step process to recover a noisy image from the generated image. + + Args: + images (`list[PIL Image]`): + list of images to encode. + steps (`int`): + Number of encoding steps to perform (defaults to `50`). + + Returns: + `np.ndarray`: + A noise tensor of shape `(batch_size, 1, height, width)`. + """ + + # Only works with DDIM as this method is deterministic + assert isinstance(self.scheduler, DDIMScheduler) + self.scheduler.set_timesteps(steps) + sample = np.array( + [np.frombuffer(image.tobytes(), dtype="uint8").reshape((1, image.height, image.width)) for image in images] + ) + sample = (sample / 255) * 2 - 1 + sample = torch.Tensor(sample).to(self.device) + + for t in self.progress_bar(torch.flip(self.scheduler.timesteps, (0,))): + prev_timestep = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps + alpha_prod_t = self.scheduler.alphas_cumprod[t] + alpha_prod_t_prev = ( + self.scheduler.alphas_cumprod[prev_timestep] + if prev_timestep >= 0 + else self.scheduler.final_alpha_cumprod + ) + beta_prod_t = 1 - alpha_prod_t + model_output = self.unet(sample, t)["sample"] + pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * model_output + sample = (sample - pred_sample_direction) * alpha_prod_t_prev ** (-0.5) + sample = sample * alpha_prod_t ** (0.5) + beta_prod_t ** (0.5) * model_output + + return sample + + @staticmethod + def slerp(x0: torch.Tensor, x1: torch.Tensor, alpha: float) -> torch.Tensor: + """Spherical Linear intERPolation. + + Args: + x0 (`torch.Tensor`): + The first tensor to interpolate between. + x1 (`torch.Tensor`): + Second tensor to interpolate between. + alpha (`float`): + Interpolation between 0 and 1 + + Returns: + `torch.Tensor`: + The interpolated tensor. + """ + + theta = acos(torch.dot(torch.flatten(x0), torch.flatten(x1)) / torch.norm(x0) / torch.norm(x1)) + return sin((1 - alpha) * theta) * x0 / sin(theta) + sin(alpha * theta) * x1 / sin(theta) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/latent_diffusion_uncond/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/latent_diffusion_uncond/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..214f5bbca969f9ae0629578c72aaf339f86ded88 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/latent_diffusion_uncond/__init__.py @@ -0,0 +1,18 @@ +from typing import TYPE_CHECKING + +from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = {"pipeline_latent_diffusion_uncond": ["LDMPipeline"]} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_latent_diffusion_uncond import LDMPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/latent_diffusion_uncond/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/latent_diffusion_uncond/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08f2d3f5996d957f2ea28980d4fecd4bb02d502f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/latent_diffusion_uncond/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e07fcfe52078676204baa85185a8172a5f9a86c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/latent_diffusion_uncond/__pycache__/pipeline_latent_diffusion_uncond.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py new file mode 100644 index 0000000000000000000000000000000000000000..70a65e2ef5be290f1443205d1239956908b54670 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -0,0 +1,129 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import torch + +from ....models import UNet2DModel, VQModel +from ....schedulers import DDIMScheduler +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class LDMPipeline(DiffusionPipeline): + r""" + Pipeline for unconditional image generation using latent diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) model to encode and decode images to and from latent representations. + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + [`DDIMScheduler`] is used in combination with `unet` to denoise the encoded image latents. + """ + + def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler): + super().__init__() + self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + eta: float = 0.0, + num_inference_steps: int = 50, + output_type: str | None = "pil", + return_dict: bool = True, + **kwargs, + ) -> tuple | ImagePipelineOutput: + r""" + The call function to the pipeline for generation. + + Args: + batch_size (`int`, *optional*, defaults to 1): + Number of images to generate. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Example: + + ```py + >>> from diffusers import LDMPipeline + + >>> # load model and scheduler + >>> pipe = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256") + + >>> # run pipeline in inference (sample random noise and denoise) + >>> image = pipe().images[0] + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + + latents = randn_tensor( + (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), + generator=generator, + ) + latents = latents.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + self.scheduler.set_timesteps(num_inference_steps) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + + extra_kwargs = {} + if accepts_eta: + extra_kwargs["eta"] = eta + + for t in self.progress_bar(self.scheduler.timesteps): + latent_model_input = self.scheduler.scale_model_input(latents, t) + # predict the noise residual + noise_prediction = self.unet(latent_model_input, t).sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample + + # adjust latents with inverse of vae scale + latents = latents / self.vqvae.config.scaling_factor + # decode the image latents with the VAE + image = self.vqvae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/pndm/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/pndm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5e3bdba74079d77576655e22b43014a0438a9c2e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/pndm/__init__.py @@ -0,0 +1,18 @@ +from typing import TYPE_CHECKING + +from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = {"pipeline_pndm": ["PNDMPipeline"]} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_pndm import PNDMPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/pndm/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/pndm/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af53f9fc4ea5d00e6ee1a84a221cdbb15f11efb5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/pndm/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/pndm/__pycache__/pipeline_pndm.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/pndm/__pycache__/pipeline_pndm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54f212d4d9d8b1d7597fc420e35cfef601fbf030 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/pndm/__pycache__/pipeline_pndm.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/pndm/pipeline_pndm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/pndm/pipeline_pndm.py new file mode 100644 index 0000000000000000000000000000000000000000..fb116511f7275750a11d2c8cf791bcedeffaff82 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/pndm/pipeline_pndm.py @@ -0,0 +1,119 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from ....models import UNet2DModel +from ....schedulers import PNDMScheduler +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class PNDMPipeline(DiffusionPipeline): + r""" + Pipeline for unconditional image generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image latents. + scheduler ([`PNDMScheduler`]): + A `PNDMScheduler` to be used in combination with `unet` to denoise the encoded image. + """ + + unet: UNet2DModel + scheduler: PNDMScheduler + + def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler): + super().__init__() + + scheduler = PNDMScheduler.from_config(scheduler.config) + + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 50, + generator: torch.Generator | list[torch.Generator] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + **kwargs, + ) -> ImagePipelineOutput | tuple: + r""" + The call function to the pipeline for generation. + + Args: + batch_size (`int`, `optional`, defaults to 1): + The number of images to generate. + num_inference_steps (`int`, `optional`, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator`, `optional`): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + output_type (`str`, `optional`, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. + + Example: + + ```py + >>> from diffusers import PNDMPipeline + + >>> # load model and scheduler + >>> pndm = PNDMPipeline.from_pretrained("google/ddpm-cifar10-32") + + >>> # run pipeline in inference (sample random noise and denoise) + >>> image = pndm().images[0] + + >>> # save image + >>> image.save("pndm_generated_image.png") + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + # For more information on the sampling method you can take a look at Algorithm 2 of + # the official paper: https://huggingface.co/papers/2202.09778 + + # Sample gaussian noise to begin loop + image = randn_tensor( + (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), + generator=generator, + device=self.device, + ) + + self.scheduler.set_timesteps(num_inference_steps) + for t in self.progress_bar(self.scheduler.timesteps): + model_output = self.unet(image, t).sample + + image = self.scheduler.step(model_output, t, image).prev_sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/repaint/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/repaint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2c6b04af52d40e8a2bfa2aa5812b9fb8b1da06f5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/repaint/__init__.py @@ -0,0 +1,19 @@ +from typing import TYPE_CHECKING + +from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = {"pipeline_repaint": ["RePaintPipeline"]} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_repaint import RePaintPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/repaint/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/repaint/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b93d45c6569b2afebfb611ecba631b17181b5d8c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/repaint/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/repaint/__pycache__/pipeline_repaint.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/repaint/__pycache__/pipeline_repaint.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30495010db88b49f3bad30ff43c7cedc0fbf8fae Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/repaint/__pycache__/pipeline_repaint.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py new file mode 100644 index 0000000000000000000000000000000000000000..3231d5e130497fd5dd091291be4a164e7f76ac6b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py @@ -0,0 +1,229 @@ +# Copyright 2025 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import PIL.Image +import torch + +from ....models import UNet2DModel +from ....schedulers import RePaintScheduler +from ....utils import PIL_INTERPOLATION, deprecate, logging +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def _preprocess_image(image: list | PIL.Image.Image | torch.Tensor): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +def _preprocess_mask(mask: list | PIL.Image.Image | torch.Tensor): + if isinstance(mask, torch.Tensor): + return mask + elif isinstance(mask, PIL.Image.Image): + mask = [mask] + + if isinstance(mask[0], PIL.Image.Image): + w, h = mask[0].size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + mask = [np.array(m.convert("L").resize((w, h), resample=PIL_INTERPOLATION["nearest"]))[None, :] for m in mask] + mask = np.concatenate(mask, axis=0) + mask = mask.astype(np.float32) / 255.0 + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + elif isinstance(mask[0], torch.Tensor): + mask = torch.cat(mask, dim=0) + return mask + + +class RePaintPipeline(DiffusionPipeline): + r""" + Pipeline for image inpainting using RePaint. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image latents. + scheduler ([`RePaintScheduler`]): + A `RePaintScheduler` to be used in combination with `unet` to denoise the encoded image. + """ + + unet: UNet2DModel + scheduler: RePaintScheduler + model_cpu_offload_seq = "unet" + + def __init__(self, unet: UNet2DModel, scheduler: RePaintScheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + image: torch.Tensor | PIL.Image.Image, + mask_image: torch.Tensor | PIL.Image.Image, + num_inference_steps: int = 250, + eta: float = 0.0, + jump_length: int = 10, + jump_n_sample: int = 10, + generator: torch.Generator | list[torch.Generator] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + ) -> ImagePipelineOutput | tuple: + r""" + The call function to the pipeline for generation. + + Args: + image (`torch.Tensor` or `PIL.Image.Image`): + The original image to inpaint on. + mask_image (`torch.Tensor` or `PIL.Image.Image`): + The mask_image where 0.0 define which part of the original image to inpaint. + num_inference_steps (`int`, *optional*, defaults to 1000): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + eta (`float`): + The weight of the added noise in a diffusion step. Its value is between 0.0 and 1.0; 0.0 corresponds to + DDIM and 1.0 is the DDPM scheduler. + jump_length (`int`, *optional*, defaults to 10): + The number of steps taken forward in time before going backward in time for a single jump ("j" in + RePaint paper). Take a look at Figure 9 and 10 in the + [paper](https://huggingface.co/papers/2201.09865). + jump_n_sample (`int`, *optional*, defaults to 10): + The number of times to make a forward time jump for a given chosen time sample. Take a look at Figure 9 + and 10 in the [paper](https://huggingface.co/papers/2201.09865). + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + output_type (`str`, `optional`, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. + + Example: + + ```py + >>> from io import BytesIO + >>> import torch + >>> import PIL + >>> import requests + >>> from diffusers import RePaintPipeline, RePaintScheduler + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/celeba_hq_256.png" + >>> mask_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png" + + >>> # Load the original image and the mask as PIL images + >>> original_image = download_image(img_url).resize((256, 256)) + >>> mask_image = download_image(mask_url).resize((256, 256)) + + >>> # Load the RePaint scheduler and pipeline based on a pretrained DDPM model + >>> scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256") + >>> pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> output = pipe( + ... image=original_image, + ... mask_image=mask_image, + ... num_inference_steps=250, + ... eta=0.0, + ... jump_length=10, + ... jump_n_sample=10, + ... generator=generator, + ... ) + >>> inpainted_image = output.images[0] + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + original_image = image + + original_image = _preprocess_image(original_image) + original_image = original_image.to(device=self._execution_device, dtype=self.unet.dtype) + mask_image = _preprocess_mask(mask_image) + mask_image = mask_image.to(device=self._execution_device, dtype=self.unet.dtype) + + batch_size = original_image.shape[0] + + # sample gaussian noise to begin the loop + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + image_shape = original_image.shape + image = randn_tensor(image_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype) + + # set step values + self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self._execution_device) + self.scheduler.eta = eta + + t_last = self.scheduler.timesteps[0] + 1 + generator = generator[0] if isinstance(generator, list) else generator + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + if t < t_last: + # predict the noise residual + model_output = self.unet(image, t).sample + # compute previous image: x_t -> x_t-1 + image = self.scheduler.step(model_output, t, image, original_image, mask_image, generator).prev_sample + + else: + # compute the reverse: x_t-1 -> x_t + image = self.scheduler.undo_step(image, t_last, generator) + t_last = t + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/score_sde_ve/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/score_sde_ve/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..87c167c3dbd26e0408a41ef197a42dc5eb7038d7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/score_sde_ve/__init__.py @@ -0,0 +1,19 @@ +from typing import TYPE_CHECKING + +from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = {"pipeline_score_sde_ve": ["ScoreSdeVePipeline"]} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_score_sde_ve import ScoreSdeVePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/score_sde_ve/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/score_sde_ve/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62bd36d8ea507bc0f6f43979cd30af84a106bf68 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/score_sde_ve/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..857411c2004de31be0eb06d056958f73b6776864 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/score_sde_ve/__pycache__/pipeline_score_sde_ve.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py new file mode 100644 index 0000000000000000000000000000000000000000..c6abdba42d3ca23ce154f96fe4a7baa052d6203f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py @@ -0,0 +1,107 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ....models import UNet2DModel +from ....schedulers import ScoreSdeVeScheduler +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class ScoreSdeVePipeline(DiffusionPipeline): + r""" + Pipeline for unconditional image generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image. + scheduler ([`ScoreSdeVeScheduler`]): + A `ScoreSdeVeScheduler` to be used in combination with `unet` to denoise the encoded image. + """ + + unet: UNet2DModel + scheduler: ScoreSdeVeScheduler + + def __init__(self, unet: UNet2DModel, scheduler: ScoreSdeVeScheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 2000, + generator: torch.Generator | list[torch.Generator] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + **kwargs, + ) -> ImagePipelineOutput | tuple: + r""" + The call function to the pipeline for generation. + + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + generator (`torch.Generator`, `optional`): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + output_type (`str`, `optional`, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + img_size = self.unet.config.sample_size + shape = (batch_size, 3, img_size, img_size) + + model = self.unet + + sample = randn_tensor(shape, generator=generator) * self.scheduler.init_noise_sigma + sample = sample.to(self.device) + + self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.set_sigmas(num_inference_steps) + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device) + + # correction step + for _ in range(self.scheduler.config.correct_steps): + model_output = self.unet(sample, sigma_t).sample + sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample + + # prediction step + model_output = model(sample, sigma_t).sample + output = self.scheduler.step_pred(model_output, t, sample, generator=generator) + + sample, sample_mean = output.prev_sample, output.prev_sample_mean + + sample = sample_mean.clamp(0, 1) + sample = sample.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + sample = self.numpy_to_pil(sample) + + if not return_dict: + return (sample,) + + return ImagePipelineOutput(images=sample) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..150954baa0eb8f8a7216b4891effc14a71e21b1b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/__init__.py @@ -0,0 +1,75 @@ +# flake8: noqa +from typing import TYPE_CHECKING +from ....utils import ( + DIFFUSERS_SLOW_IMPORT, + _LazyModule, + is_note_seq_available, + OptionalDependencyNotAvailable, + is_torch_available, + is_transformers_available, + get_objects_from_module, +) + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ....utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["continous_encoder"] = ["SpectrogramContEncoder"] + _import_structure["notes_encoder"] = ["SpectrogramNotesEncoder"] + _import_structure["pipeline_spectrogram_diffusion"] = [ + "SpectrogramContEncoder", + "SpectrogramDiffusionPipeline", + "T5FilmDecoder", + ] +try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ....utils import dummy_transformers_and_torch_and_note_seq_objects + + _dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects)) +else: + _import_structure["midi_utils"] = ["MidiProcessor"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ....utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_spectrogram_diffusion import SpectrogramDiffusionPipeline + from .pipeline_spectrogram_diffusion import SpectrogramContEncoder + from .pipeline_spectrogram_diffusion import SpectrogramNotesEncoder + from .pipeline_spectrogram_diffusion import T5FilmDecoder + + try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ....utils.dummy_transformers_and_torch_and_note_seq_objects import * + + else: + from .midi_utils import MidiProcessor + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c0b949e4da917e933b191ae53cb4f7532ed9839 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/__pycache__/continuous_encoder.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/__pycache__/continuous_encoder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42394330aa3de51c8ab23cc51a44a26da4324f88 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/__pycache__/continuous_encoder.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/__pycache__/midi_utils.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/__pycache__/midi_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cb38a9f85e1de44d33058bd0264fe1caae4d3a7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/__pycache__/midi_utils.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/__pycache__/notes_encoder.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/__pycache__/notes_encoder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08a99dd7b4cc475e2aea2b8d92f5c70a0c979d6e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/__pycache__/notes_encoder.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/__pycache__/pipeline_spectrogram_diffusion.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/__pycache__/pipeline_spectrogram_diffusion.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca67c90fcd7ac2bbd55a7813cc56a520d1218ebe Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/__pycache__/pipeline_spectrogram_diffusion.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b26e84f728693b29f62cc058c43810fc4983619b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py @@ -0,0 +1,92 @@ +# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from transformers.modeling_utils import ModuleUtilsMixin +from transformers.models.t5.modeling_t5 import ( + T5Block, + T5Config, + T5LayerNorm, +) + +from ....configuration_utils import ConfigMixin, register_to_config +from ....models import ModelMixin + + +class SpectrogramContEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): + @register_to_config + def __init__( + self, + input_dims: int, + targets_context_length: int, + d_model: int, + dropout_rate: float, + num_layers: int, + num_heads: int, + d_kv: int, + d_ff: int, + feed_forward_proj: str, + is_decoder: bool = False, + ): + super().__init__() + + self.input_proj = nn.Linear(input_dims, d_model, bias=False) + + self.position_encoding = nn.Embedding(targets_context_length, d_model) + self.position_encoding.weight.requires_grad = False + + self.dropout_pre = nn.Dropout(p=dropout_rate) + + t5config = T5Config( + d_model=d_model, + num_heads=num_heads, + d_kv=d_kv, + d_ff=d_ff, + feed_forward_proj=feed_forward_proj, + dropout_rate=dropout_rate, + is_decoder=is_decoder, + is_encoder_decoder=False, + ) + self.encoders = nn.ModuleList() + for lyr_num in range(num_layers): + lyr = T5Block(t5config) + self.encoders.append(lyr) + + self.layer_norm = T5LayerNorm(d_model) + self.dropout_post = nn.Dropout(p=dropout_rate) + + def forward(self, encoder_inputs, encoder_inputs_mask): + x = self.input_proj(encoder_inputs) + + # terminal relative positional encodings + max_positions = encoder_inputs.shape[1] + input_positions = torch.arange(max_positions, device=encoder_inputs.device) + + seq_lens = encoder_inputs_mask.sum(-1) + input_positions = torch.roll(input_positions.unsqueeze(0), tuple(seq_lens.tolist()), dims=0) + x += self.position_encoding(input_positions) + + x = self.dropout_pre(x) + + # inverted the attention mask + input_shape = encoder_inputs.size() + extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape) + + for lyr in self.encoders: + x = lyr(x, extended_attention_mask)[0] + x = self.layer_norm(x) + + return self.dropout_post(x), encoder_inputs_mask diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..76b8576468d2c4502f4970d441292deda3c60965 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py @@ -0,0 +1,667 @@ +# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import math +import os +from typing import Any, Callable, Mapping, MutableMapping, Sequence + +import numpy as np +import torch +import torch.nn.functional as F + +from ....utils import is_note_seq_available +from .pipeline_spectrogram_diffusion import TARGET_FEATURE_LENGTH + + +if is_note_seq_available(): + import note_seq +else: + raise ImportError("Please install note-seq via `pip install note-seq`") + + +INPUT_FEATURE_LENGTH = 2048 + +SAMPLE_RATE = 16000 +HOP_SIZE = 320 +FRAME_RATE = int(SAMPLE_RATE // HOP_SIZE) + +DEFAULT_STEPS_PER_SECOND = 100 +DEFAULT_MAX_SHIFT_SECONDS = 10 +DEFAULT_NUM_VELOCITY_BINS = 1 + +SLAKH_CLASS_PROGRAMS = { + "Acoustic Piano": 0, + "Electric Piano": 4, + "Chromatic Percussion": 8, + "Organ": 16, + "Acoustic Guitar": 24, + "Clean Electric Guitar": 26, + "Distorted Electric Guitar": 29, + "Acoustic Bass": 32, + "Electric Bass": 33, + "Violin": 40, + "Viola": 41, + "Cello": 42, + "Contrabass": 43, + "Orchestral Harp": 46, + "Timpani": 47, + "String Ensemble": 48, + "Synth Strings": 50, + "Choir and Voice": 52, + "Orchestral Hit": 55, + "Trumpet": 56, + "Trombone": 57, + "Tuba": 58, + "French Horn": 60, + "Brass Section": 61, + "Soprano/Alto Sax": 64, + "Tenor Sax": 66, + "Baritone Sax": 67, + "Oboe": 68, + "English Horn": 69, + "Bassoon": 70, + "Clarinet": 71, + "Pipe": 73, + "Synth Lead": 80, + "Synth Pad": 88, +} + + +@dataclasses.dataclass +class NoteRepresentationConfig: + """Configuration note representations.""" + + onsets_only: bool + include_ties: bool + + +@dataclasses.dataclass +class NoteEventData: + pitch: int + velocity: int | None = None + program: int | None = None + is_drum: bool | None = None + instrument: int | None = None + + +@dataclasses.dataclass +class NoteEncodingState: + """Encoding state for note transcription, keeping track of active pitches.""" + + # velocity bin for active pitches and programs + active_pitches: MutableMapping[tuple[int, int], int] = dataclasses.field(default_factory=dict) + + +@dataclasses.dataclass +class EventRange: + type: str + min_value: int + max_value: int + + +@dataclasses.dataclass +class Event: + type: str + value: int + + +class Tokenizer: + def __init__(self, regular_ids: int): + # The special tokens: 0=PAD, 1=EOS, and 2=UNK + self._num_special_tokens = 3 + self._num_regular_tokens = regular_ids + + def encode(self, token_ids): + encoded = [] + for token_id in token_ids: + if not 0 <= token_id < self._num_regular_tokens: + raise ValueError( + f"token_id {token_id} does not fall within valid range of [0, {self._num_regular_tokens})" + ) + encoded.append(token_id + self._num_special_tokens) + + # Add EOS token + encoded.append(1) + + # Pad to till INPUT_FEATURE_LENGTH + encoded = encoded + [0] * (INPUT_FEATURE_LENGTH - len(encoded)) + + return encoded + + +class Codec: + """Encode and decode events. + + Useful for declaring what certain ranges of a vocabulary should be used for. This is intended to be used from + Python before encoding or after decoding with GenericTokenVocabulary. This class is more lightweight and does not + include things like EOS or UNK token handling. + + To ensure that 'shift' events are always the first block of the vocab and start at 0, that event type is required + and specified separately. + """ + + def __init__(self, max_shift_steps: int, steps_per_second: float, event_ranges: list[EventRange]): + """Define Codec. + + Args: + max_shift_steps: Maximum number of shift steps that can be encoded. + steps_per_second: Shift steps will be interpreted as having a duration of + 1 / steps_per_second. + event_ranges: Other supported event types and their ranges. + """ + self.steps_per_second = steps_per_second + self._shift_range = EventRange(type="shift", min_value=0, max_value=max_shift_steps) + self._event_ranges = [self._shift_range] + event_ranges + # Ensure all event types have unique names. + assert len(self._event_ranges) == len({er.type for er in self._event_ranges}) + + @property + def num_classes(self) -> int: + return sum(er.max_value - er.min_value + 1 for er in self._event_ranges) + + # The next couple methods are simplified special case methods just for shift + # events that are intended to be used from within autograph functions. + + def is_shift_event_index(self, index: int) -> bool: + return (self._shift_range.min_value <= index) and (index <= self._shift_range.max_value) + + @property + def max_shift_steps(self) -> int: + return self._shift_range.max_value + + def encode_event(self, event: Event) -> int: + """Encode an event to an index.""" + offset = 0 + for er in self._event_ranges: + if event.type == er.type: + if not er.min_value <= event.value <= er.max_value: + raise ValueError( + f"Event value {event.value} is not within valid range " + f"[{er.min_value}, {er.max_value}] for type {event.type}" + ) + return offset + event.value - er.min_value + offset += er.max_value - er.min_value + 1 + + raise ValueError(f"Unknown event type: {event.type}") + + def event_type_range(self, event_type: str) -> tuple[int, int]: + """Return [min_id, max_id] for an event type.""" + offset = 0 + for er in self._event_ranges: + if event_type == er.type: + return offset, offset + (er.max_value - er.min_value) + offset += er.max_value - er.min_value + 1 + + raise ValueError(f"Unknown event type: {event_type}") + + def decode_event_index(self, index: int) -> Event: + """Decode an event index to an Event.""" + offset = 0 + for er in self._event_ranges: + if offset <= index <= offset + er.max_value - er.min_value: + return Event(type=er.type, value=er.min_value + index - offset) + offset += er.max_value - er.min_value + 1 + + raise ValueError(f"Unknown event index: {index}") + + +@dataclasses.dataclass +class ProgramGranularity: + # both tokens_map_fn and program_map_fn should be idempotent + tokens_map_fn: Callable[[Sequence[int], Codec], Sequence[int]] + program_map_fn: Callable[[int], int] + + +def drop_programs(tokens, codec: Codec): + """Drops program change events from a token sequence.""" + min_program_id, max_program_id = codec.event_type_range("program") + return tokens[(tokens < min_program_id) | (tokens > max_program_id)] + + +def programs_to_midi_classes(tokens, codec): + """Modifies program events to be the first program in the MIDI class.""" + min_program_id, max_program_id = codec.event_type_range("program") + is_program = (tokens >= min_program_id) & (tokens <= max_program_id) + return np.where(is_program, min_program_id + 8 * ((tokens - min_program_id) // 8), tokens) + + +PROGRAM_GRANULARITIES = { + # "flat" granularity; drop program change tokens and set NoteSequence + # programs to zero + "flat": ProgramGranularity(tokens_map_fn=drop_programs, program_map_fn=lambda program: 0), + # map each program to the first program in its MIDI class + "midi_class": ProgramGranularity( + tokens_map_fn=programs_to_midi_classes, program_map_fn=lambda program: 8 * (program // 8) + ), + # leave programs as is + "full": ProgramGranularity(tokens_map_fn=lambda tokens, codec: tokens, program_map_fn=lambda program: program), +} + + +def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1): + """ + equivalent of tf.signal.frame + """ + signal_length = signal.shape[axis] + if pad_end: + frames_overlap = frame_length - frame_step + rest_samples = np.abs(signal_length - frames_overlap) % np.abs(frame_length - frames_overlap) + pad_size = int(frame_length - rest_samples) + + if pad_size != 0: + pad_axis = [0] * signal.ndim + pad_axis[axis] = pad_size + signal = F.pad(signal, pad_axis, "constant", pad_value) + frames = signal.unfold(axis, frame_length, frame_step) + return frames + + +def program_to_slakh_program(program): + # this is done very hackily, probably should use a custom mapping + for slakh_program in sorted(SLAKH_CLASS_PROGRAMS.values(), reverse=True): + if program >= slakh_program: + return slakh_program + + +def audio_to_frames( + samples, + hop_size: int, + frame_rate: int, +) -> tuple[Sequence[Sequence[int]], torch.Tensor]: + """Convert audio samples to non-overlapping frames and frame times.""" + frame_size = hop_size + samples = np.pad(samples, [0, frame_size - len(samples) % frame_size], mode="constant") + + # Split audio into frames. + frames = frame( + torch.Tensor(samples).unsqueeze(0), + frame_length=frame_size, + frame_step=frame_size, + pad_end=False, # TODO check why its off by 1 here when True + ) + + num_frames = len(samples) // frame_size + + times = np.arange(num_frames) / frame_rate + return frames, times + + +def note_sequence_to_onsets_and_offsets_and_programs( + ns: note_seq.NoteSequence, +) -> tuple[Sequence[float], Sequence[NoteEventData]]: + """Extract onset & offset times and pitches & programs from a NoteSequence. + + The onset & offset times will not necessarily be in sorted order. + + Args: + ns: NoteSequence from which to extract onsets and offsets. + + Returns: + times: A list of note onset and offset times. values: A list of NoteEventData objects where velocity is zero for + note + offsets. + """ + # Sort by program and pitch and put offsets before onsets as a tiebreaker for + # subsequent stable sort. + notes = sorted(ns.notes, key=lambda note: (note.is_drum, note.program, note.pitch)) + times = [note.end_time for note in notes if not note.is_drum] + [note.start_time for note in notes] + values = [ + NoteEventData(pitch=note.pitch, velocity=0, program=note.program, is_drum=False) + for note in notes + if not note.is_drum + ] + [ + NoteEventData(pitch=note.pitch, velocity=note.velocity, program=note.program, is_drum=note.is_drum) + for note in notes + ] + return times, values + + +def num_velocity_bins_from_codec(codec: Codec): + """Get number of velocity bins from event codec.""" + lo, hi = codec.event_type_range("velocity") + return hi - lo + + +# segment an array into segments of length n +def segment(a, n): + return [a[i : i + n] for i in range(0, len(a), n)] + + +def velocity_to_bin(velocity, num_velocity_bins): + if velocity == 0: + return 0 + else: + return math.ceil(num_velocity_bins * velocity / note_seq.MAX_MIDI_VELOCITY) + + +def note_event_data_to_events( + state: NoteEncodingState | None, + value: NoteEventData, + codec: Codec, +) -> Sequence[Event]: + """Convert note event data to a sequence of events.""" + if value.velocity is None: + # onsets only, no program or velocity + return [Event("pitch", value.pitch)] + else: + num_velocity_bins = num_velocity_bins_from_codec(codec) + velocity_bin = velocity_to_bin(value.velocity, num_velocity_bins) + if value.program is None: + # onsets + offsets + velocities only, no programs + if state is not None: + state.active_pitches[(value.pitch, 0)] = velocity_bin + return [Event("velocity", velocity_bin), Event("pitch", value.pitch)] + else: + if value.is_drum: + # drum events use a separate vocabulary + return [Event("velocity", velocity_bin), Event("drum", value.pitch)] + else: + # program + velocity + pitch + if state is not None: + state.active_pitches[(value.pitch, value.program)] = velocity_bin + return [ + Event("program", value.program), + Event("velocity", velocity_bin), + Event("pitch", value.pitch), + ] + + +def note_encoding_state_to_events(state: NoteEncodingState) -> Sequence[Event]: + """Output program and pitch events for active notes plus a final tie event.""" + events = [] + for pitch, program in sorted(state.active_pitches.keys(), key=lambda k: k[::-1]): + if state.active_pitches[(pitch, program)]: + events += [Event("program", program), Event("pitch", pitch)] + events.append(Event("tie", 0)) + return events + + +def encode_and_index_events( + state, event_times, event_values, codec, frame_times, encode_event_fn, encoding_state_to_events_fn=None +): + """Encode a sequence of timed events and index to audio frame times. + + Encodes time shifts as repeated single step shifts for later run length encoding. + + Optionally, also encodes a sequence of "state events", keeping track of the current encoding state at each audio + frame. This can be used e.g. to prepend events representing the current state to a targets segment. + + Args: + state: Initial event encoding state. + event_times: Sequence of event times. + event_values: Sequence of event values. + encode_event_fn: Function that transforms event value into a sequence of one + or more Event objects. + codec: An Codec object that maps Event objects to indices. + frame_times: Time for every audio frame. + encoding_state_to_events_fn: Function that transforms encoding state into a + sequence of one or more Event objects. + + Returns: + events: Encoded events and shifts. event_start_indices: Corresponding start event index for every audio frame. + Note: one event can correspond to multiple audio indices due to sampling rate differences. This makes + splitting sequences tricky because the same event can appear at the end of one sequence and the beginning of + another. + event_end_indices: Corresponding end event index for every audio frame. Used + to ensure when slicing that one chunk ends where the next begins. Should always be true that + event_end_indices[i] = event_start_indices[i + 1]. + state_events: Encoded "state" events representing the encoding state before + each event. + state_event_indices: Corresponding state event index for every audio frame. + """ + indices = np.argsort(event_times, kind="stable") + event_steps = [round(event_times[i] * codec.steps_per_second) for i in indices] + event_values = [event_values[i] for i in indices] + + events = [] + state_events = [] + event_start_indices = [] + state_event_indices = [] + + cur_step = 0 + cur_event_idx = 0 + cur_state_event_idx = 0 + + def fill_event_start_indices_to_cur_step(): + while ( + len(event_start_indices) < len(frame_times) + and frame_times[len(event_start_indices)] < cur_step / codec.steps_per_second + ): + event_start_indices.append(cur_event_idx) + state_event_indices.append(cur_state_event_idx) + + for event_step, event_value in zip(event_steps, event_values): + while event_step > cur_step: + events.append(codec.encode_event(Event(type="shift", value=1))) + cur_step += 1 + fill_event_start_indices_to_cur_step() + cur_event_idx = len(events) + cur_state_event_idx = len(state_events) + if encoding_state_to_events_fn: + # Dump state to state events *before* processing the next event, because + # we want to capture the state prior to the occurrence of the event. + for e in encoding_state_to_events_fn(state): + state_events.append(codec.encode_event(e)) + + for e in encode_event_fn(state, event_value, codec): + events.append(codec.encode_event(e)) + + # After the last event, continue filling out the event_start_indices array. + # The inequality is not strict because if our current step lines up exactly + # with (the start of) an audio frame, we need to add an additional shift event + # to "cover" that frame. + while cur_step / codec.steps_per_second <= frame_times[-1]: + events.append(codec.encode_event(Event(type="shift", value=1))) + cur_step += 1 + fill_event_start_indices_to_cur_step() + cur_event_idx = len(events) + + # Now fill in event_end_indices. We need this extra array to make sure that + # when we slice events, each slice ends exactly where the subsequent slice + # begins. + event_end_indices = event_start_indices[1:] + [len(events)] + + events = np.array(events).astype(np.int32) + state_events = np.array(state_events).astype(np.int32) + event_start_indices = segment(np.array(event_start_indices).astype(np.int32), TARGET_FEATURE_LENGTH) + event_end_indices = segment(np.array(event_end_indices).astype(np.int32), TARGET_FEATURE_LENGTH) + state_event_indices = segment(np.array(state_event_indices).astype(np.int32), TARGET_FEATURE_LENGTH) + + outputs = [] + for start_indices, end_indices, event_indices in zip(event_start_indices, event_end_indices, state_event_indices): + outputs.append( + { + "inputs": events, + "event_start_indices": start_indices, + "event_end_indices": end_indices, + "state_events": state_events, + "state_event_indices": event_indices, + } + ) + + return outputs + + +def extract_sequence_with_indices(features, state_events_end_token=None, feature_key="inputs"): + """Extract target sequence corresponding to audio token segment.""" + features = features.copy() + start_idx = features["event_start_indices"][0] + end_idx = features["event_end_indices"][-1] + + features[feature_key] = features[feature_key][start_idx:end_idx] + + if state_events_end_token is not None: + # Extract the state events corresponding to the audio start token, and + # prepend them to the targets array. + state_event_start_idx = features["state_event_indices"][0] + state_event_end_idx = state_event_start_idx + 1 + while features["state_events"][state_event_end_idx - 1] != state_events_end_token: + state_event_end_idx += 1 + features[feature_key] = np.concatenate( + [ + features["state_events"][state_event_start_idx:state_event_end_idx], + features[feature_key], + ], + axis=0, + ) + + return features + + +def map_midi_programs( + feature, codec: Codec, granularity_type: str = "full", feature_key: str = "inputs" +) -> Mapping[str, Any]: + """Apply MIDI program map to token sequences.""" + granularity = PROGRAM_GRANULARITIES[granularity_type] + + feature[feature_key] = granularity.tokens_map_fn(feature[feature_key], codec) + return feature + + +def run_length_encode_shifts_fn( + features, + codec: Codec, + feature_key: str = "inputs", + state_change_event_types: Sequence[str] = (), +) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]: + """Return a function that run-length encodes shifts for a given codec. + + Args: + codec: The Codec to use for shift events. + feature_key: The feature key for which to run-length encode shifts. + state_change_event_types: A list of event types that represent state + changes; tokens corresponding to these event types will be interpreted as state changes and redundant ones + will be removed. + + Returns: + A preprocessing function that run-length encodes single-step shifts. + """ + state_change_event_ranges = [codec.event_type_range(event_type) for event_type in state_change_event_types] + + def run_length_encode_shifts(features: MutableMapping[str, Any]) -> Mapping[str, Any]: + """Combine leading/interior shifts, trim trailing shifts. + + Args: + features: Dict of features to process. + + Returns: + A dict of features. + """ + events = features[feature_key] + + shift_steps = 0 + total_shift_steps = 0 + output = np.array([], dtype=np.int32) + + current_state = np.zeros(len(state_change_event_ranges), dtype=np.int32) + + for event in events: + if codec.is_shift_event_index(event): + shift_steps += 1 + total_shift_steps += 1 + + else: + # If this event is a state change and has the same value as the current + # state, we can skip it entirely. + is_redundant = False + for i, (min_index, max_index) in enumerate(state_change_event_ranges): + if (min_index <= event) and (event <= max_index): + if current_state[i] == event: + is_redundant = True + current_state[i] = event + if is_redundant: + continue + + # Once we've reached a non-shift event, RLE all previous shift events + # before outputting the non-shift event. + if shift_steps > 0: + shift_steps = total_shift_steps + while shift_steps > 0: + output_steps = np.minimum(codec.max_shift_steps, shift_steps) + output = np.concatenate([output, [output_steps]], axis=0) + shift_steps -= output_steps + output = np.concatenate([output, [event]], axis=0) + + features[feature_key] = output + return features + + return run_length_encode_shifts(features) + + +def note_representation_processor_chain(features, codec: Codec, note_representation_config: NoteRepresentationConfig): + tie_token = codec.encode_event(Event("tie", 0)) + state_events_end_token = tie_token if note_representation_config.include_ties else None + + features = extract_sequence_with_indices( + features, state_events_end_token=state_events_end_token, feature_key="inputs" + ) + + features = map_midi_programs(features, codec) + + features = run_length_encode_shifts_fn(features, codec, state_change_event_types=["velocity", "program"]) + + return features + + +class MidiProcessor: + def __init__(self): + self.codec = Codec( + max_shift_steps=DEFAULT_MAX_SHIFT_SECONDS * DEFAULT_STEPS_PER_SECOND, + steps_per_second=DEFAULT_STEPS_PER_SECOND, + event_ranges=[ + EventRange("pitch", note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH), + EventRange("velocity", 0, DEFAULT_NUM_VELOCITY_BINS), + EventRange("tie", 0, 0), + EventRange("program", note_seq.MIN_MIDI_PROGRAM, note_seq.MAX_MIDI_PROGRAM), + EventRange("drum", note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH), + ], + ) + self.tokenizer = Tokenizer(self.codec.num_classes) + self.note_representation_config = NoteRepresentationConfig(onsets_only=False, include_ties=True) + + def __call__(self, midi: bytes | os.PathLike | str): + if not isinstance(midi, bytes): + with open(midi, "rb") as f: + midi = f.read() + + ns = note_seq.midi_to_note_sequence(midi) + ns_sus = note_seq.apply_sustain_control_changes(ns) + + for note in ns_sus.notes: + if not note.is_drum: + note.program = program_to_slakh_program(note.program) + + samples = np.zeros(int(ns_sus.total_time * SAMPLE_RATE)) + + _, frame_times = audio_to_frames(samples, HOP_SIZE, FRAME_RATE) + times, values = note_sequence_to_onsets_and_offsets_and_programs(ns_sus) + + events = encode_and_index_events( + state=NoteEncodingState(), + event_times=times, + event_values=values, + frame_times=frame_times, + codec=self.codec, + encode_event_fn=note_event_data_to_events, + encoding_state_to_events_fn=note_encoding_state_to_events, + ) + + events = [ + note_representation_processor_chain(event, self.codec, self.note_representation_config) for event in events + ] + input_tokens = [self.tokenizer.encode(event["inputs"]) for event in events] + + return input_tokens diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..25ad4a4ccfd28b2423ee9771130a24c941c06649 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py @@ -0,0 +1,86 @@ +# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from transformers.modeling_utils import ModuleUtilsMixin +from transformers.models.t5.modeling_t5 import T5Block, T5Config, T5LayerNorm + +from ....configuration_utils import ConfigMixin, register_to_config +from ....models import ModelMixin + + +class SpectrogramNotesEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): + @register_to_config + def __init__( + self, + max_length: int, + vocab_size: int, + d_model: int, + dropout_rate: float, + num_layers: int, + num_heads: int, + d_kv: int, + d_ff: int, + feed_forward_proj: str, + is_decoder: bool = False, + ): + super().__init__() + + self.token_embedder = nn.Embedding(vocab_size, d_model) + + self.position_encoding = nn.Embedding(max_length, d_model) + self.position_encoding.weight.requires_grad = False + + self.dropout_pre = nn.Dropout(p=dropout_rate) + + t5config = T5Config( + vocab_size=vocab_size, + d_model=d_model, + num_heads=num_heads, + d_kv=d_kv, + d_ff=d_ff, + dropout_rate=dropout_rate, + feed_forward_proj=feed_forward_proj, + is_decoder=is_decoder, + is_encoder_decoder=False, + ) + + self.encoders = nn.ModuleList() + for lyr_num in range(num_layers): + lyr = T5Block(t5config) + self.encoders.append(lyr) + + self.layer_norm = T5LayerNorm(d_model) + self.dropout_post = nn.Dropout(p=dropout_rate) + + def forward(self, encoder_input_tokens, encoder_inputs_mask): + x = self.token_embedder(encoder_input_tokens) + + seq_length = encoder_input_tokens.shape[1] + inputs_positions = torch.arange(seq_length, device=encoder_input_tokens.device) + x += self.position_encoding(inputs_positions) + + x = self.dropout_pre(x) + + # inverted the attention mask + input_shape = encoder_input_tokens.size() + extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape) + + for lyr in self.encoders: + x = lyr(x, extended_attention_mask)[0] + x = self.layer_norm(x) + + return self.dropout_post(x), encoder_inputs_mask diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..269e7405d10d2be9f44d481b05725b2918727ea9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py @@ -0,0 +1,269 @@ +# Copyright 2022 The Music Spectrogram Diffusion Authors. +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any, Callable + +import numpy as np +import torch + +from ....models import T5FilmDecoder +from ....schedulers import DDPMScheduler +from ....utils import is_onnx_available, logging +from ....utils.torch_utils import randn_tensor + + +if is_onnx_available(): + from ...onnx_utils import OnnxRuntimeModel + +from ...pipeline_utils import AudioPipelineOutput, DiffusionPipeline +from .continuous_encoder import SpectrogramContEncoder +from .notes_encoder import SpectrogramNotesEncoder + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +TARGET_FEATURE_LENGTH = 256 + + +class SpectrogramDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for unconditional audio generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + notes_encoder ([`SpectrogramNotesEncoder`]): + continuous_encoder ([`SpectrogramContEncoder`]): + decoder ([`T5FilmDecoder`]): + A [`T5FilmDecoder`] to denoise the encoded audio latents. + scheduler ([`DDPMScheduler`]): + A scheduler to be used in combination with `decoder` to denoise the encoded audio latents. + melgan ([`OnnxRuntimeModel`]): + """ + + _optional_components = ["melgan"] + + def __init__( + self, + notes_encoder: SpectrogramNotesEncoder, + continuous_encoder: SpectrogramContEncoder, + decoder: T5FilmDecoder, + scheduler: DDPMScheduler, + melgan: OnnxRuntimeModel if is_onnx_available() else Any, + ) -> None: + super().__init__() + + # From MELGAN + self.min_value = math.log(1e-5) # Matches MelGAN training. + self.max_value = 4.0 # Largest value for most examples + self.n_dims = 128 + + self.register_modules( + notes_encoder=notes_encoder, + continuous_encoder=continuous_encoder, + decoder=decoder, + scheduler=scheduler, + melgan=melgan, + ) + + def scale_features(self, features, output_range=(-1.0, 1.0), clip=False): + """Linearly scale features to network outputs range.""" + min_out, max_out = output_range + if clip: + features = torch.clip(features, self.min_value, self.max_value) + # Scale to [0, 1]. + zero_one = (features - self.min_value) / (self.max_value - self.min_value) + # Scale to [min_out, max_out]. + return zero_one * (max_out - min_out) + min_out + + def scale_to_features(self, outputs, input_range=(-1.0, 1.0), clip=False): + """Invert by linearly scaling network outputs to features range.""" + min_out, max_out = input_range + outputs = torch.clip(outputs, min_out, max_out) if clip else outputs + # Scale to [0, 1]. + zero_one = (outputs - min_out) / (max_out - min_out) + # Scale to [self.min_value, self.max_value]. + return zero_one * (self.max_value - self.min_value) + self.min_value + + def encode(self, input_tokens, continuous_inputs, continuous_mask): + tokens_mask = input_tokens > 0 + tokens_encoded, tokens_mask = self.notes_encoder( + encoder_input_tokens=input_tokens, encoder_inputs_mask=tokens_mask + ) + + continuous_encoded, continuous_mask = self.continuous_encoder( + encoder_inputs=continuous_inputs, encoder_inputs_mask=continuous_mask + ) + + return [(tokens_encoded, tokens_mask), (continuous_encoded, continuous_mask)] + + def decode(self, encodings_and_masks, input_tokens, noise_time): + timesteps = noise_time + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=input_tokens.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(input_tokens.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps * torch.ones(input_tokens.shape[0], dtype=timesteps.dtype, device=timesteps.device) + + logits = self.decoder( + encodings_and_masks=encodings_and_masks, decoder_input_tokens=input_tokens, decoder_noise_time=timesteps + ) + return logits + + @torch.no_grad() + def __call__( + self, + input_tokens: list[list[int]], + generator: torch.Generator | None = None, + num_inference_steps: int = 100, + return_dict: bool = True, + output_type: str = "np", + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + ) -> AudioPipelineOutput | tuple: + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + r""" + The call function to the pipeline for generation. + + Args: + input_tokens (`list[list[int]]`): + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality audio at the + expense of slower inference. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated audio. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Example: + + ```py + >>> from diffusers import SpectrogramDiffusionPipeline, MidiProcessor + + >>> pipe = SpectrogramDiffusionPipeline.from_pretrained("google/music-spectrogram-diffusion") + >>> pipe = pipe.to("cuda") + >>> processor = MidiProcessor() + + >>> # Download MIDI from: wget http://www.piano-midi.de/midis/beethoven/beethoven_hammerklavier_2.mid + >>> output = pipe(processor("beethoven_hammerklavier_2.mid")) + + >>> audio = output.audios[0] + ``` + + Returns: + [`pipelines.AudioPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`pipelines.AudioPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated audio. + """ + + pred_mel = np.zeros([1, TARGET_FEATURE_LENGTH, self.n_dims], dtype=np.float32) + full_pred_mel = np.zeros([1, 0, self.n_dims], np.float32) + ones = torch.ones((1, TARGET_FEATURE_LENGTH), dtype=bool, device=self.device) + + for i, encoder_input_tokens in enumerate(input_tokens): + if i == 0: + encoder_continuous_inputs = torch.from_numpy(pred_mel[:1].copy()).to( + device=self.device, dtype=self.decoder.dtype + ) + # The first chunk has no previous context. + encoder_continuous_mask = torch.zeros((1, TARGET_FEATURE_LENGTH), dtype=bool, device=self.device) + else: + # The full song pipeline does not feed in a context feature, so the mask + # will be all 0s after the feature converter. Because we know we're + # feeding in a full context chunk from the previous prediction, set it + # to all 1s. + encoder_continuous_mask = ones + + encoder_continuous_inputs = self.scale_features( + encoder_continuous_inputs, output_range=[-1.0, 1.0], clip=True + ) + + encodings_and_masks = self.encode( + input_tokens=torch.IntTensor([encoder_input_tokens]).to(device=self.device), + continuous_inputs=encoder_continuous_inputs, + continuous_mask=encoder_continuous_mask, + ) + + # Sample encoder_continuous_inputs shaped gaussian noise to begin loop + x = randn_tensor( + shape=encoder_continuous_inputs.shape, + generator=generator, + device=self.device, + dtype=self.decoder.dtype, + ) + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + + # Denoising diffusion loop + for j, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + output = self.decode( + encodings_and_masks=encodings_and_masks, + input_tokens=x, + noise_time=t / self.scheduler.config.num_train_timesteps, # rescale to [0, 1) + ) + + # Compute previous output: x_t -> x_t-1 + x = self.scheduler.step(output, t, x, generator=generator).prev_sample + + mel = self.scale_to_features(x, input_range=[-1.0, 1.0]) + encoder_continuous_inputs = mel[:1] + pred_mel = mel.cpu().float().numpy() + + full_pred_mel = np.concatenate([full_pred_mel, pred_mel[:1]], axis=1) + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, full_pred_mel) + + logger.info("Generated segment", i) + + if output_type == "np" and not is_onnx_available(): + raise ValueError( + "Cannot return output in 'np' format if ONNX is not available. Make sure to have ONNX installed or set 'output_type' to 'mel'." + ) + elif output_type == "np" and self.melgan is None: + raise ValueError( + "Cannot return output in 'np' format if melgan component is not defined. Make sure to define `self.melgan` or set 'output_type' to 'mel'." + ) + + if output_type == "np": + output = self.melgan(input_features=full_pred_mel.astype(np.float32)) + else: + output = full_pred_mel + + if not return_dict: + return (output,) + + return AudioPipelineOutput(audios=output) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36cf1a33ce6ada8e718aabadb9a706737aee30bd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py @@ -0,0 +1,55 @@ +from typing import TYPE_CHECKING + +from ....utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ....utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_cycle_diffusion"] = ["CycleDiffusionPipeline"] + _import_structure["pipeline_stable_diffusion_inpaint_legacy"] = ["StableDiffusionInpaintPipelineLegacy"] + _import_structure["pipeline_stable_diffusion_model_editing"] = ["StableDiffusionModelEditingPipeline"] + + _import_structure["pipeline_stable_diffusion_paradigms"] = ["StableDiffusionParadigmsPipeline"] + _import_structure["pipeline_stable_diffusion_pix2pix_zero"] = ["StableDiffusionPix2PixZeroPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ....utils.dummy_torch_and_transformers_objects import * + + else: + from .pipeline_cycle_diffusion import CycleDiffusionPipeline + from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy + from .pipeline_stable_diffusion_model_editing import StableDiffusionModelEditingPipeline + from .pipeline_stable_diffusion_paradigms import StableDiffusionParadigmsPipeline + from .pipeline_stable_diffusion_pix2pix_zero import StableDiffusionPix2PixZeroPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59ab5a72bb97681c00adb3d1375e0408d2c5cb91 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/__pycache__/pipeline_cycle_diffusion.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/__pycache__/pipeline_cycle_diffusion.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24c1a8d7ef64ccae3c603ee24eea86d785e4cdfd Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/__pycache__/pipeline_cycle_diffusion.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/__pycache__/pipeline_onnx_stable_diffusion_inpaint_legacy.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/__pycache__/pipeline_onnx_stable_diffusion_inpaint_legacy.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0920053990f4638296e94d83ad890a2ccedf33b8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/__pycache__/pipeline_onnx_stable_diffusion_inpaint_legacy.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/__pycache__/pipeline_stable_diffusion_inpaint_legacy.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/__pycache__/pipeline_stable_diffusion_inpaint_legacy.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5b6599952929584458ae315e13d87b87f02c232 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/__pycache__/pipeline_stable_diffusion_inpaint_legacy.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/__pycache__/pipeline_stable_diffusion_model_editing.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/__pycache__/pipeline_stable_diffusion_model_editing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65a892ae437ca9e81727587979c5eb32bf875859 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/__pycache__/pipeline_stable_diffusion_model_editing.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/__pycache__/pipeline_stable_diffusion_paradigms.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/__pycache__/pipeline_stable_diffusion_paradigms.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..334c81bc1cd7f69acee40c86a414b4592b558604 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/__pycache__/pipeline_stable_diffusion_paradigms.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/__pycache__/pipeline_stable_diffusion_pix2pix_zero.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/__pycache__/pipeline_stable_diffusion_pix2pix_zero.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17cc7935cd479012263f8e35398dfa5d86e649e1 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/__pycache__/pipeline_stable_diffusion_pix2pix_zero.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..dae5e600d773a388b5fa5dc79dfde78005429554 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py @@ -0,0 +1,953 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ....configuration_utils import FrozenDict +from ....image_processor import PipelineImageInput, VaeImageProcessor +from ....loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ....models import AutoencoderKL, UNet2DConditionModel +from ....models.lora import adjust_lora_scale_text_encoder +from ....schedulers import DDIMScheduler +from ....utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline +from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta): + # 1. get previous step value (=t-1) + prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + + if prev_timestep <= 0: + return clean_latents + + # 2. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod + ) + + variance = scheduler._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + # direction pointing to x_t + e_t = (latents - alpha_prod_t ** (0.5) * clean_latents) / (1 - alpha_prod_t) ** (0.5) + dir_xt = (1.0 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * e_t + noise = std_dev_t * randn_tensor( + clean_latents.shape, dtype=clean_latents.dtype, device=clean_latents.device, generator=generator + ) + prev_latents = alpha_prod_t_prev ** (0.5) * clean_latents + dir_xt + noise + + return prev_latents + + +def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred, eta): + # 1. get previous step value (=t-1) + prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod + ) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://huggingface.co/papers/2010.02502 + pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) + + # 4. Clip "predicted x_0" + if scheduler.config.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = scheduler._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://huggingface.co/papers/2010.02502 + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred + + noise = (prev_latents - (alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction)) / ( + variance ** (0.5) * eta + ) + return noise + + +class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin): + r""" + Pipeline for text-guided image to image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can only be an + instance of [`DDIMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + image = image.to(device=device, dtype=dtype) + + batch_size = image.shape[0] + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) + + # add noise to latents using the timestep + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + clean_latents = init_latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents, clean_latents + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str], + source_prompt: str | list[str], + image: PipelineImageInput = None, + strength: float = 0.8, + num_inference_steps: int | None = 50, + guidance_scale: float | None = 7.5, + source_guidance_scale: float | None = 1, + num_images_per_prompt: int | None = 1, + eta: float | None = 0.1, + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + image (`torch.Tensor` `np.ndarray`, `PIL.Image.Image`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image` or tensor representing an image batch to be used as the starting point. Can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + source_guidance_scale (`float`, *optional*, defaults to 1): + Guidance scale for the source prompt. This is useful to control the amount of influence the source + prompt has for encoding. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Example: + + ```py + import requests + import torch + from PIL import Image + from io import BytesIO + + from diffusers import CycleDiffusionPipeline, DDIMScheduler + + # load the pipeline + # make sure you're logged in with `hf auth login` + model_id_or_path = "CompVis/stable-diffusion-v1-4" + scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler") + pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda") + + # let's download an initial image + url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/An%20astronaut%20riding%20a%20horse.png" + response = requests.get(url) + init_image = Image.open(BytesIO(response.content)).convert("RGB") + init_image = init_image.resize((512, 512)) + init_image.save("horse.png") + + # let's specify a prompt + source_prompt = "An astronaut riding a horse" + prompt = "An astronaut riding an elephant" + + # call the pipeline + image = pipe( + prompt=prompt, + source_prompt=source_prompt, + image=init_image, + num_inference_steps=100, + eta=0.1, + strength=0.8, + guidance_scale=2, + source_guidance_scale=1, + ).images[0] + + image.save("horse_to_elephant.png") + + # let's try another example + # See more samples at the original repo: https://github.com/ChenWu98/cycle-diffusion + url = ( + "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/A%20black%20colored%20car.png" + ) + response = requests.get(url) + init_image = Image.open(BytesIO(response.content)).convert("RGB") + init_image = init_image.resize((512, 512)) + init_image.save("black.png") + + source_prompt = "A black colored car" + prompt = "A blue colored car" + + # call the pipeline + torch.manual_seed(0) + image = pipe( + prompt=prompt, + source_prompt=source_prompt, + image=init_image, + num_inference_steps=100, + eta=0.1, + strength=0.85, + guidance_scale=3, + source_guidance_scale=1, + ).images[0] + + image.save("black_to_blue.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 1. Check inputs + self.check_inputs(prompt, strength, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds_tuple = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + source_prompt_embeds_tuple = self.encode_prompt( + source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None, clip_skip=clip_skip + ) + if prompt_embeds_tuple[1] is not None: + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + else: + prompt_embeds = prompt_embeds_tuple[0] + if source_prompt_embeds_tuple[1] is not None: + source_prompt_embeds = torch.cat([source_prompt_embeds_tuple[1], source_prompt_embeds_tuple[0]]) + else: + source_prompt_embeds = source_prompt_embeds_tuple[0] + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents, clean_latents = self.prepare_latents( + image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator + ) + source_latents = latents + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + generator = extra_step_kwargs.pop("generator", None) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + source_latent_model_input = ( + torch.cat([source_latents] * 2) if do_classifier_free_guidance else source_latents + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t) + + # predict the noise residual + if do_classifier_free_guidance: + concat_latent_model_input = torch.stack( + [ + source_latent_model_input[0], + latent_model_input[0], + source_latent_model_input[1], + latent_model_input[1], + ], + dim=0, + ) + concat_prompt_embeds = torch.stack( + [ + source_prompt_embeds[0], + prompt_embeds[0], + source_prompt_embeds[1], + prompt_embeds[1], + ], + dim=0, + ) + else: + concat_latent_model_input = torch.cat( + [ + source_latent_model_input, + latent_model_input, + ], + dim=0, + ) + concat_prompt_embeds = torch.cat( + [ + source_prompt_embeds, + prompt_embeds, + ], + dim=0, + ) + + concat_noise_pred = self.unet( + concat_latent_model_input, + t, + cross_attention_kwargs=cross_attention_kwargs, + encoder_hidden_states=concat_prompt_embeds, + ).sample + + # perform guidance + if do_classifier_free_guidance: + ( + source_noise_pred_uncond, + noise_pred_uncond, + source_noise_pred_text, + noise_pred_text, + ) = concat_noise_pred.chunk(4, dim=0) + + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + source_noise_pred = source_noise_pred_uncond + source_guidance_scale * ( + source_noise_pred_text - source_noise_pred_uncond + ) + + else: + (source_noise_pred, noise_pred) = concat_noise_pred.chunk(2, dim=0) + + # Sample source_latents from the posterior distribution. + prev_source_latents = posterior_sample( + self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs + ) + # Compute noise. + noise = compute_noise( + self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs + ) + source_latents = prev_source_latents + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs + ).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # 9. Post-processing + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py new file mode 100644 index 0000000000000000000000000000000000000000..f526dc419cea63d5e60fadae61a10c07f266dc1a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py @@ -0,0 +1,543 @@ +import inspect +from typing import Callable + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTokenizer + +from ....configuration_utils import FrozenDict +from ....schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ....utils import deprecate, logging +from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from ...pipeline_utils import DiffusionPipeline +from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def preprocess(image): + w, h = image.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask, scale_factor=8): + mask = mask.convert("L") + w, h = mask.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + return mask + + +class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. This is a *legacy feature* for Onnx pipelines to + provide compatibility with StableDiffusionInpaintPipelineLegacy and may be removed in the future. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + _optional_components = ["safety_checker", "feature_extractor"] + _is_onnx = True + + vae_encoder: OnnxRuntimeModel + vae_decoder: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + scheduler: DDIMScheduler | PNDMScheduler | LMSDiscreteScheduler + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPImageProcessor + + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: DDIMScheduler | PNDMScheduler | LMSDiscreteScheduler, + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt: str | list[str], + num_images_per_prompt: int | None, + do_classifier_free_guidance: bool, + negative_prompt: str | None, + prompt_embeds: np.ndarray | None = None, + negative_prompt_embeds: np.ndarray | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: + negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def check_inputs( + self, + prompt, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def __call__( + self, + prompt: str | list[str], + image: np.ndarray | PIL.Image.Image = None, + mask_image: np.ndarray | PIL.Image.Image = None, + strength: float = 0.8, + num_inference_steps: int | None = 50, + guidance_scale: float | None = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float | None = 0.0, + generator: np.random.RandomState | None = None, + prompt_embeds: np.ndarray | None = None, + negative_prompt_embeds: np.ndarray | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, np.ndarray], None] | None = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + image (`nd.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. + mask_image (`nd.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.uu + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (?) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + A np.random.RandomState to make generation deterministic. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # check inputs. Raise error if not correct + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if generator is None: + generator = np.random + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + if isinstance(image, PIL.Image.Image): + image = preprocess(image) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + latents_dtype = prompt_embeds.dtype + image = image.astype(latents_dtype) + + # encode the init image into latents and scale the latents + init_latents = self.vae_encoder(sample=image)[0] + init_latents = 0.18215 * init_latents + + # Expand init_latents for batch_size and num_images_per_prompt + init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0) + init_latents_orig = init_latents + + # preprocess mask + if not isinstance(mask_image, np.ndarray): + mask_image = preprocess_mask(mask_image, 8) + mask_image = mask_image.astype(latents_dtype) + mask = np.concatenate([mask_image] * num_images_per_prompt, axis=0) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and image should be the same size!") + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps.numpy()[-init_timestep] + timesteps = np.array([timesteps] * batch_size * num_images_per_prompt) + + # add noise to latents using the timesteps + noise = generator.randn(*init_latents.shape).astype(latents_dtype) + init_latents = self.scheduler.add_noise( + torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) + ) + init_latents = init_latents.numpy() + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (?) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to ? in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + latents = init_latents + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].numpy() + timestep_dtype = next( + (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[ + 0 + ] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ).prev_sample + + latents = latents.numpy() + + init_latents_proper = self.scheduler.add_noise( + torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.from_numpy(np.array([t])) + ) + + init_latents_proper = init_latents_proper.numpy() + + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + latents = 1 / 0.18215 * latents + # image = self.vae_decoder(latent_sample=latents)[0] + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="np" + ).pixel_values.astype(image.dtype) + # There will throw an error if use safety_checker batchsize>1 + images, has_nsfw_concept = [], [] + for i in range(image.shape[0]): + image_i, has_nsfw_concept_i = self.safety_checker( + clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + ) + images.append(image_i) + has_nsfw_concept.append(has_nsfw_concept_i[0]) + image = np.concatenate(images) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py new file mode 100644 index 0000000000000000000000000000000000000000..650695b604c1014f1b02d78ace239e82029014ee --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py @@ -0,0 +1,789 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ....configuration_utils import FrozenDict +from ....image_processor import VaeImageProcessor +from ....loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ....models import AutoencoderKL, UNet2DConditionModel +from ....models.lora import adjust_lora_scale_text_encoder +from ....schedulers import KarrasDiffusionSchedulers +from ....utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline +from ...stable_diffusion import StableDiffusionPipelineOutput +from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) + + +def preprocess_image(image, batch_size): + w, h = image.size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = np.array(image).astype(np.float32) / 255.0 + image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask, batch_size, scale_factor=8): + if not isinstance(mask, torch.Tensor): + mask = mask.convert("L") + w, h = mask.size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = np.vstack([mask[None]] * batch_size) + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + else: + valid_mask_channel_sizes = [1, 3] + # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W) + if mask.shape[3] in valid_mask_channel_sizes: + mask = mask.permute(0, 3, 1, 2) + elif mask.shape[1] not in valid_mask_channel_sizes: + raise ValueError( + f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension," + f" but received mask of shape {tuple(mask.shape)}" + ) + # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape + mask = mask.mean(dim=1, keepdim=True) + h, w = mask.shape[-2:] + h, w = (x - x % 8 for x in (h, w)) # resize to integer multiple of 8 + mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor)) + return mask + + +class StableDiffusionInpaintPipelineLegacy( + DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] + - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] + + as well as the following saving methods: + - *LoRA*: [`loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["feature_extractor"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + deprecation_message = ( + f"The class {self.__class__} is deprecated and will be removed in v1.0.0. You can achieve exactly the same functionality" + "by loading your model into `StableDiffusionInpaintPipeline` instead. See https://github.com/huggingface/diffusers/pull/3533" + "for more information." + ) + deprecate("legacy is outdated", "1.0.0", deprecation_message, standard_warn=False) + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, image, timestep, num_images_per_prompt, dtype, device, generator): + image = image.to(device=device, dtype=dtype) + init_latent_dist = self.vae.encode(image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = self.vae.config.scaling_factor * init_latents + + # Expand init_latents for batch_size and num_images_per_prompt + init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) + init_latents_orig = init_latents + + # add noise to latents using the timesteps + noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype) + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + return latents, init_latents_orig, noise + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str] = None, + image: torch.Tensor | PIL.Image.Image = None, + mask_image: torch.Tensor | PIL.Image.Image = None, + strength: float = 0.8, + num_inference_steps: int | None = 50, + guidance_scale: float | None = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + add_predicted_noise: bool | None = False, + eta: float | None = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.Tensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. + mask_image (`torch.Tensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If mask is a tensor, the + expected shape should be either `(B, H, W, C)` or `(B, C, H, W)`, where C is 1 or 3. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` + is 1, the denoising process will be run on the masked area for the full number of iterations specified + in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more noise to + that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + num_inference_steps (`int`, *optional*, defaults to 50): + The reference number of denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. This parameter will be modulated by `strength`, as explained above. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + add_predicted_noise (`bool`, *optional*, defaults to True): + Use predicted noise instead of random noise when constructing noisy versions of the original image in + the reverse diffusion process + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 1. Check inputs + self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Preprocess image and mask + if not isinstance(image, torch.Tensor): + image = preprocess_image(image, batch_size) + + mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + # encode the init image into latents and scale the latents + latents, init_latents_orig, noise = self.prepare_latents( + image, latent_timestep, num_images_per_prompt, prompt_embeds.dtype, device, generator + ) + + # 7. Prepare mask latent + mask = mask_image.to(device=device, dtype=latents.dtype) + mask = torch.cat([mask] * num_images_per_prompt) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + # masking + if add_predicted_noise: + init_latents_proper = self.scheduler.add_noise( + init_latents_orig, noise_pred_uncond, torch.tensor([t]) + ) + else: + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) + + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # use original latents corresponding to unmasked portions of the image + latents = (init_latents_orig * mask) + (latents * (1 - mask)) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py new file mode 100644 index 0000000000000000000000000000000000000000..851820c00aedf6089aa11c1c59d4d663fa41435b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py @@ -0,0 +1,832 @@ +# Copyright 2025 TIME Authors and The HuggingFace Team. All rights reserved." +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +from typing import Any, Callable + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ....image_processor import VaeImageProcessor +from ....loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ....models import AutoencoderKL, UNet2DConditionModel +from ....models.lora import adjust_lora_scale_text_encoder +from ....schedulers import PNDMScheduler +from ....schedulers.scheduling_utils import SchedulerMixin +from ....utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +AUGS_CONST = ["A photo of ", "An image of ", "A picture of "] + + +class StableDiffusionModelEditingPipeline( + DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin +): + r""" + Pipeline for text-to-image model editing. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + with_to_k ([`bool`]): + Whether to edit the key projection matrices along with the value projection matrices. + with_augs ([`list`]): + Textual augmentations to apply while editing the text-to-image model. Set to `[]` for no augmentations. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: SchedulerMixin, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + with_to_k: bool = True, + with_augs: list = AUGS_CONST, + ): + super().__init__() + + if isinstance(scheduler, PNDMScheduler): + logger.error("PNDMScheduler for this pipeline is currently not supported.") + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + self.with_to_k = with_to_k + self.with_augs = with_augs + + # get cross-attention layers + ca_layers = [] + + def append_ca(net_): + if net_.__class__.__name__ == "CrossAttention": + ca_layers.append(net_) + elif hasattr(net_, "children"): + for net__ in net_.children(): + append_ca(net__) + + # recursively find all cross-attention layers in unet + for net in self.unet.named_children(): + if "down" in net[0]: + append_ca(net[1]) + elif "up" in net[0]: + append_ca(net[1]) + elif "mid" in net[0]: + append_ca(net[1]) + + # get projection matrices + self.ca_clip_layers = [l for l in ca_layers if l.to_v.in_features == 768] + self.projection_matrices = [l.to_v for l in self.ca_clip_layers] + self.og_matrices = [copy.deepcopy(l.to_v) for l in self.ca_clip_layers] + if self.with_to_k: + self.projection_matrices = self.projection_matrices + [l.to_k for l in self.ca_clip_layers] + self.og_matrices = self.og_matrices + [copy.deepcopy(l.to_k) for l in self.ca_clip_layers] + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def edit_model( + self, + source_prompt: str, + destination_prompt: str, + lamb: float = 0.1, + restart_params: bool = True, + ): + r""" + Apply model editing via closed-form solution (see Eq. 5 in the TIME + [paper](https://huggingface.co/papers/2303.08084)). + + Args: + source_prompt (`str`): + The source prompt containing the concept to be edited. + destination_prompt (`str`): + The destination prompt. Must contain all words from `source_prompt` with additional ones to specify the + target edit. + lamb (`float`, *optional*, defaults to 0.1): + The lambda parameter specifying the regularization intensity. Smaller values increase the editing + power. + restart_params (`bool`, *optional*, defaults to True): + Restart the model parameters to their pre-trained version before editing. This is done to avoid edit + compounding. When it is `False`, edits accumulate. + """ + + # restart LDM parameters + if restart_params: + num_ca_clip_layers = len(self.ca_clip_layers) + for idx_, l in enumerate(self.ca_clip_layers): + l.to_v = copy.deepcopy(self.og_matrices[idx_]) + self.projection_matrices[idx_] = l.to_v + if self.with_to_k: + l.to_k = copy.deepcopy(self.og_matrices[num_ca_clip_layers + idx_]) + self.projection_matrices[num_ca_clip_layers + idx_] = l.to_k + + # set up sentences + old_texts = [source_prompt] + new_texts = [destination_prompt] + # add augmentations + base = old_texts[0] if old_texts[0][0:1] != "A" else "a" + old_texts[0][1:] + for aug in self.with_augs: + old_texts.append(aug + base) + base = new_texts[0] if new_texts[0][0:1] != "A" else "a" + new_texts[0][1:] + for aug in self.with_augs: + new_texts.append(aug + base) + + # prepare input k* and v* + old_embs, new_embs = [], [] + for old_text, new_text in zip(old_texts, new_texts): + text_input = self.tokenizer( + [old_text, new_text], + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + old_emb, new_emb = text_embeddings + old_embs.append(old_emb) + new_embs.append(new_emb) + + # identify corresponding destinations for each token in old_emb + idxs_replaces = [] + for old_text, new_text in zip(old_texts, new_texts): + tokens_a = self.tokenizer(old_text).input_ids + tokens_b = self.tokenizer(new_text).input_ids + tokens_a = [self.tokenizer.encode("a ")[1] if self.tokenizer.decode(t) == "an" else t for t in tokens_a] + tokens_b = [self.tokenizer.encode("a ")[1] if self.tokenizer.decode(t) == "an" else t for t in tokens_b] + num_orig_tokens = len(tokens_a) + idxs_replace = [] + j = 0 + for i in range(num_orig_tokens): + curr_token = tokens_a[i] + while tokens_b[j] != curr_token: + j += 1 + idxs_replace.append(j) + j += 1 + while j < 77: + idxs_replace.append(j) + j += 1 + while len(idxs_replace) < 77: + idxs_replace.append(76) + idxs_replaces.append(idxs_replace) + + # prepare batch: for each pair of sentences, old context and new values + contexts, valuess = [], [] + for old_emb, new_emb, idxs_replace in zip(old_embs, new_embs, idxs_replaces): + context = old_emb.detach() + values = [] + with torch.no_grad(): + for layer in self.projection_matrices: + values.append(layer(new_emb[idxs_replace]).detach()) + contexts.append(context) + valuess.append(values) + + # edit the model + for layer_num in range(len(self.projection_matrices)): + # mat1 = \lambda W + \sum{v k^T} + mat1 = lamb * self.projection_matrices[layer_num].weight + + # mat2 = \lambda I + \sum{k k^T} + mat2 = lamb * torch.eye( + self.projection_matrices[layer_num].weight.shape[1], + device=self.projection_matrices[layer_num].weight.device, + ) + + # aggregate sums for mat1, mat2 + for context, values in zip(contexts, valuess): + context_vector = context.reshape(context.shape[0], context.shape[1], 1) + context_vector_T = context.reshape(context.shape[0], 1, context.shape[1]) + value_vector = values[layer_num].reshape(values[layer_num].shape[0], values[layer_num].shape[1], 1) + for_mat1 = (value_vector @ context_vector_T).sum(dim=0) + for_mat2 = (context_vector @ context_vector_T).sum(dim=0) + mat1 += for_mat1 + mat2 += for_mat2 + + # update projection matrix + self.projection_matrices[layer_num].weight = torch.nn.Parameter(mat1 @ torch.inverse(mat2)) + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Examples: + + ```py + >>> import torch + >>> from diffusers import StableDiffusionModelEditingPipeline + + >>> model_ckpt = "CompVis/stable-diffusion-v1-4" + >>> pipe = StableDiffusionModelEditingPipeline.from_pretrained(model_ckpt) + + >>> pipe = pipe.to("cuda") + + >>> source_prompt = "A pack of roses" + >>> destination_prompt = "A pack of blue roses" + >>> pipe.edit_model(source_prompt, destination_prompt) + + >>> prompt = "A field of roses" + >>> image = pipe(prompt).images[0] + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py new file mode 100644 index 0000000000000000000000000000000000000000..ea81be87a0f4b246a0b328e050da5a07357141e9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py @@ -0,0 +1,798 @@ +# Copyright 2025 ParaDiGMS authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ....image_processor import VaeImageProcessor +from ....loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ....models import AutoencoderKL, UNet2DConditionModel +from ....models.lora import adjust_lora_scale_text_encoder +from ....schedulers import KarrasDiffusionSchedulers +from ....utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import DDPMParallelScheduler + >>> from diffusers import StableDiffusionParadigmsPipeline + + >>> scheduler = DDPMParallelScheduler.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="scheduler" + ... ) + + >>> pipe = StableDiffusionParadigmsPipeline.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-v1-5", scheduler=scheduler, torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> ngpu, batch_per_device = torch.cuda.device_count(), 5 + >>> pipe.wrapped_unet = torch.nn.DataParallel(pipe.unet, device_ids=[d for d in range(ngpu)]) + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, parallel=ngpu * batch_per_device, num_inference_steps=1000).images[0] + ``` +""" + + +class StableDiffusionParadigmsPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using a parallelized version of Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # attribute to wrap the unet with torch.nn.DataParallel when running multiple denoising steps on multiple GPUs + self.wrapped_unet = self.unet + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _cumsum(self, input, dim, debug=False): + if debug: + # cumsum_cuda_kernel does not have a deterministic implementation + # so perform cumsum on cpu for debugging purposes + return torch.cumsum(input.cpu().float(), dim=dim).to(input.device) + else: + return torch.cumsum(input, dim=dim) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + parallel: int = 10, + tolerance: float = 0.1, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + debug: bool = False, + clip_skip: int = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + parallel (`int`, *optional*, defaults to 10): + The batch size to use when doing parallel sampling. More parallelism may lead to faster inference but + requires higher memory usage and can also require more total FLOPs. + tolerance (`float`, *optional*, defaults to 0.1): + The error tolerance for determining when to slide the batch window forward for parallel sampling. Lower + tolerance usually leads to less or no degradation. Higher tolerance is faster but can risk degradation + of sample quality. The tolerance is specified as a ratio of the scheduler's noise magnitude. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + debug (`bool`, *optional*, defaults to `False`): + Whether or not to run in debug mode. In debug mode, `torch.cumsum` is evaluated using the CPU. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + extra_step_kwargs.pop("generator", None) + + # # 7. Denoising loop + scheduler = self.scheduler + parallel = min(parallel, len(scheduler.timesteps)) + + begin_idx = 0 + end_idx = parallel + latents_time_evolution_buffer = torch.stack([latents] * (len(scheduler.timesteps) + 1)) + + # We must make sure the noise of stochastic schedulers such as DDPM is sampled only once per timestep. + # Sampling inside the parallel denoising loop will mess this up, so we pre-sample the noise vectors outside the denoising loop. + noise_array = torch.zeros_like(latents_time_evolution_buffer) + for j in range(len(scheduler.timesteps)): + base_noise = randn_tensor( + shape=latents.shape, generator=generator, device=latents.device, dtype=prompt_embeds.dtype + ) + noise = (self.scheduler._get_variance(scheduler.timesteps[j]) ** 0.5) * base_noise + noise_array[j] = noise.clone() + + # We specify the error tolerance as a ratio of the scheduler's noise magnitude. We similarly compute the error tolerance + # outside of the denoising loop to avoid recomputing it at every step. + # We will be dividing the norm of the noise, so we store its inverse here to avoid a division at every step. + inverse_variance_norm = 1.0 / torch.tensor( + [scheduler._get_variance(scheduler.timesteps[j]) for j in range(len(scheduler.timesteps))] + [0] + ).to(noise_array.device) + latent_dim = noise_array[0, 0].numel() + inverse_variance_norm = inverse_variance_norm[:, None] / latent_dim + + scaled_tolerance = tolerance**2 + + with self.progress_bar(total=num_inference_steps) as progress_bar: + steps = 0 + while begin_idx < len(scheduler.timesteps): + # these have shape (parallel_dim, 2*batch_size, ...) + # parallel_len is at most parallel, but could be less if we are at the end of the timesteps + # we are processing batch window of timesteps spanning [begin_idx, end_idx) + parallel_len = end_idx - begin_idx + + block_prompt_embeds = torch.stack([prompt_embeds] * parallel_len) + block_latents = latents_time_evolution_buffer[begin_idx:end_idx] + block_t = scheduler.timesteps[begin_idx:end_idx, None].repeat(1, batch_size * num_images_per_prompt) + t_vec = block_t + if do_classifier_free_guidance: + t_vec = t_vec.repeat(1, 2) + + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([block_latents] * 2, dim=1) if do_classifier_free_guidance else block_latents + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t_vec) + + # if parallel_len is small, no need to use multiple GPUs + net = self.wrapped_unet if parallel_len > 3 else self.unet + # predict the noise residual, shape is now [parallel_len * 2 * batch_size * num_images_per_prompt, ...] + model_output = net( + latent_model_input.flatten(0, 1), + t_vec.flatten(0, 1), + encoder_hidden_states=block_prompt_embeds.flatten(0, 1), + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + per_latent_shape = model_output.shape[1:] + if do_classifier_free_guidance: + model_output = model_output.reshape( + parallel_len, 2, batch_size * num_images_per_prompt, *per_latent_shape + ) + noise_pred_uncond, noise_pred_text = model_output[:, 0], model_output[:, 1] + model_output = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + model_output = model_output.reshape( + parallel_len * batch_size * num_images_per_prompt, *per_latent_shape + ) + + block_latents_denoise = scheduler.batch_step_no_noise( + model_output=model_output, + timesteps=block_t.flatten(0, 1), + sample=block_latents.flatten(0, 1), + **extra_step_kwargs, + ).reshape(block_latents.shape) + + # back to shape (parallel_dim, batch_size, ...) + # now we want to add the pre-sampled noise + # parallel sampling algorithm requires computing the cumulative drift from the beginning + # of the window, so we need to compute cumulative sum of the deltas and the pre-sampled noises. + delta = block_latents_denoise - block_latents + cumulative_delta = self._cumsum(delta, dim=0, debug=debug) + cumulative_noise = self._cumsum(noise_array[begin_idx:end_idx], dim=0, debug=debug) + + # if we are using an ODE-like scheduler (like DDIM), we don't want to add noise + if scheduler._is_ode_scheduler: + cumulative_noise = 0 + + block_latents_new = ( + latents_time_evolution_buffer[begin_idx][None,] + cumulative_delta + cumulative_noise + ) + cur_error = torch.linalg.norm( + (block_latents_new - latents_time_evolution_buffer[begin_idx + 1 : end_idx + 1]).reshape( + parallel_len, batch_size * num_images_per_prompt, -1 + ), + dim=-1, + ).pow(2) + error_ratio = cur_error * inverse_variance_norm[begin_idx + 1 : end_idx + 1] + + # find the first index of the vector error_ratio that is greater than error tolerance + # we can shift the window for the next iteration up to this index + error_ratio = torch.nn.functional.pad( + error_ratio, (0, 0, 0, 1), value=1e9 + ) # handle the case when everything is below ratio, by padding the end of parallel_len dimension + any_error_at_time = torch.max(error_ratio > scaled_tolerance, dim=1).values.int() + ind = torch.argmax(any_error_at_time).item() + + # compute the new begin and end idxs for the window + new_begin_idx = begin_idx + min(1 + ind, parallel) + new_end_idx = min(new_begin_idx + parallel, len(scheduler.timesteps)) + + # store the computed latents for the current window in the global buffer + latents_time_evolution_buffer[begin_idx + 1 : end_idx + 1] = block_latents_new + # initialize the new sliding window latents with the end of the current window, + # should be better than random initialization + latents_time_evolution_buffer[end_idx : new_end_idx + 1] = latents_time_evolution_buffer[end_idx][ + None, + ] + + steps += 1 + + progress_bar.update(new_begin_idx - begin_idx) + if callback is not None and steps % callback_steps == 0: + callback(begin_idx, block_t[begin_idx], latents_time_evolution_buffer[begin_idx]) + + begin_idx = new_begin_idx + end_idx = new_end_idx + + latents = latents_time_evolution_buffer[-1] + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py new file mode 100644 index 0000000000000000000000000000000000000000..0955b6fe48a149af5f76ce717729fbb21a1238a1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py @@ -0,0 +1,1311 @@ +# Copyright 2025 Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import ( + BlipForConditionalGeneration, + BlipProcessor, + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, +) + +from ....image_processor import PipelineImageInput, VaeImageProcessor +from ....loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ....models import AutoencoderKL, UNet2DConditionModel +from ....models.attention_processor import Attention +from ....models.lora import adjust_lora_scale_text_encoder +from ....schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler +from ....schedulers.scheduling_ddim_inverse import DDIMInverseScheduler +from ....utils import ( + PIL_INTERPOLATION, + USE_PEFT_BACKEND, + BaseOutput, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class Pix2PixInversionPipelineOutput(BaseOutput, TextualInversionLoaderMixin): + """ + Output class for Stable Diffusion pipelines. + + Args: + latents (`torch.Tensor`) + inverted latents tensor + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + latents: torch.Tensor + images: list[PIL.Image.Image] | np.ndarray + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import requests + >>> import torch + + >>> from diffusers import DDIMScheduler, StableDiffusionPix2PixZeroPipeline + + + >>> def download(embedding_url, local_filepath): + ... r = requests.get(embedding_url) + ... with open(local_filepath, "wb") as f: + ... f.write(r.content) + + + >>> model_ckpt = "CompVis/stable-diffusion-v1-4" + >>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16) + >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.to("cuda") + + >>> prompt = "a high resolution painting of a cat in the style of van gough" + >>> source_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/cat.pt" + >>> target_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/dog.pt" + + >>> for url in [source_emb_url, target_emb_url]: + ... download(url, url.split("/")[-1]) + + >>> src_embeds = torch.load(source_emb_url.split("/")[-1]) + >>> target_embeds = torch.load(target_emb_url.split("/")[-1]) + >>> images = pipeline( + ... prompt, + ... source_embeds=src_embeds, + ... target_embeds=target_embeds, + ... num_inference_steps=50, + ... cross_attention_guidance_amount=0.15, + ... ).images + + >>> images[0].save("edited_image_dog.png") + ``` +""" + +EXAMPLE_INVERT_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from transformers import BlipForConditionalGeneration, BlipProcessor + >>> from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionPix2PixZeroPipeline + + >>> import requests + >>> from PIL import Image + + >>> captioner_id = "Salesforce/blip-image-captioning-base" + >>> processor = BlipProcessor.from_pretrained(captioner_id) + >>> model = BlipForConditionalGeneration.from_pretrained( + ... captioner_id, torch_dtype=torch.float16, low_cpu_mem_usage=True + ... ) + + >>> sd_model_ckpt = "CompVis/stable-diffusion-v1-4" + >>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained( + ... sd_model_ckpt, + ... caption_generator=model, + ... caption_processor=processor, + ... torch_dtype=torch.float16, + ... safety_checker=None, + ... ) + + >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.enable_model_cpu_offload() + + >>> img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png" + + >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512)) + >>> # generate caption + >>> caption = pipeline.generate_caption(raw_image) + + >>> # "a photography of a cat with flowers and dai dai daie - daie - daie kasaii" + >>> inv_latents = pipeline.invert(caption, image=raw_image).latents + >>> # we need to generate source and target embeds + + >>> source_prompts = ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"] + + >>> target_prompts = ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"] + + >>> source_embeds = pipeline.get_embeds(source_prompts) + >>> target_embeds = pipeline.get_embeds(target_prompts) + >>> # the latents can then be used to edit a real image + >>> # when using Stable Diffusion 2 or other models that use v-prediction + >>> # set `cross_attention_guidance_amount` to 0.01 or less to avoid input latent gradient explosion + + >>> image = pipeline( + ... caption, + ... source_embeds=source_embeds, + ... target_embeds=target_embeds, + ... num_inference_steps=50, + ... cross_attention_guidance_amount=0.15, + ... generator=generator, + ... latents=inv_latents, + ... negative_prompt=caption, + ... ).images[0] + >>> image.save("edited_image.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +def prepare_unet(unet: UNet2DConditionModel): + """Modifies the UNet (`unet`) to perform Pix2Pix Zero optimizations.""" + pix2pix_zero_attn_procs = {} + for name in unet.attn_processors.keys(): + module_name = name.replace(".processor", "") + module = unet.get_submodule(module_name) + if "attn2" in name: + pix2pix_zero_attn_procs[name] = Pix2PixZeroAttnProcessor(is_pix2pix_zero=True) + module.requires_grad_(True) + else: + pix2pix_zero_attn_procs[name] = Pix2PixZeroAttnProcessor(is_pix2pix_zero=False) + module.requires_grad_(False) + + unet.set_attn_processor(pix2pix_zero_attn_procs) + return unet + + +class Pix2PixZeroL2Loss: + def __init__(self): + self.loss = 0.0 + + def compute_loss(self, predictions, targets): + self.loss += ((predictions - targets) ** 2).sum((1, 2)).mean(0) + + +class Pix2PixZeroAttnProcessor: + """An attention processor class to store the attention weights. + In Pix2Pix Zero, it happens during computations in the cross-attention blocks.""" + + def __init__(self, is_pix2pix_zero=False): + self.is_pix2pix_zero = is_pix2pix_zero + if self.is_pix2pix_zero: + self.reference_cross_attn_map = {} + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + timestep=None, + loss=None, + ): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + if self.is_pix2pix_zero and timestep is not None: + # new bookkeeping to save the attention weights. + if loss is None: + self.reference_cross_attn_map[timestep.item()] = attention_probs.detach().cpu() + # compute loss + elif loss is not None: + prev_attn_probs = self.reference_cross_attn_map.pop(timestep.item()) + loss.compute_loss(attention_probs, prev_attn_probs.to(attention_probs.device)) + + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin): + r""" + Pipeline for pixel-level image editing using Pix2Pix Zero. Based on Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], or [`DDPMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + requires_safety_checker (bool): + Whether the pipeline requires a safety checker. We recommend setting it to True if you're using the + pipeline publicly. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = [ + "safety_checker", + "feature_extractor", + "caption_generator", + "caption_processor", + "inverse_scheduler", + ] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler | DDIMScheduler | EulerAncestralDiscreteScheduler | LMSDiscreteScheduler, + feature_extractor: CLIPImageProcessor, + safety_checker: StableDiffusionSafetyChecker, + inverse_scheduler: DDIMInverseScheduler, + caption_generator: BlipForConditionalGeneration, + caption_processor: BlipProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + caption_processor=caption_processor, + caption_generator=caption_generator, + inverse_scheduler=inverse_scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + source_embeds, + target_embeds, + callback_steps, + prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if source_embeds is None and target_embeds is None: + raise ValueError("`source_embeds` and `target_embeds` cannot be undefined.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def generate_caption(self, images): + """Generates caption for a given image.""" + text = "a photography of" + + prev_device = self.caption_generator.device + + device = self._execution_device + inputs = self.caption_processor(images, text, return_tensors="pt").to( + device=device, dtype=self.caption_generator.dtype + ) + self.caption_generator.to(device) + outputs = self.caption_generator.generate(**inputs, max_new_tokens=128) + + # offload caption generator + self.caption_generator.to(prev_device) + + caption = self.caption_processor.batch_decode(outputs, skip_special_tokens=True)[0] + return caption + + def construct_direction(self, embs_source: torch.Tensor, embs_target: torch.Tensor): + """Constructs the edit direction to steer the image generation process semantically.""" + return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0) + + @torch.no_grad() + def get_embeds(self, prompt: list[str], batch_size: int = 16) -> torch.Tensor: + num_prompts = len(prompt) + embeds = [] + for i in range(0, num_prompts, batch_size): + prompt_slice = prompt[i : i + batch_size] + + input_ids = self.tokenizer( + prompt_slice, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ).input_ids + + input_ids = input_ids.to(self.text_encoder.device) + embeds.append(self.text_encoder(input_ids)[0]) + + return torch.cat(embeds, dim=0).mean(0)[None] + + def prepare_image_latents(self, image, batch_size, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 4: + latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0) + else: + latents = self.vae.encode(image).latent_dist.sample(generator) + + latents = self.vae.config.scaling_factor * latents + + if batch_size != latents.shape[0]: + if batch_size % latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_latents_per_image = batch_size // latents.shape[0] + latents = torch.cat([latents] * additional_latents_per_image, dim=0) + else: + raise ValueError( + f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts." + ) + else: + latents = torch.cat([latents], dim=0) + + return latents + + def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int): + pred_type = self.inverse_scheduler.config.prediction_type + alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep] + + beta_prod_t = 1 - alpha_prod_t + + if pred_type == "epsilon": + return model_output + elif pred_type == "sample": + return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5) + elif pred_type == "v_prediction": + return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`" + ) + + def auto_corr_loss(self, hidden_states, generator=None): + reg_loss = 0.0 + for i in range(hidden_states.shape[0]): + for j in range(hidden_states.shape[1]): + noise = hidden_states[i : i + 1, j : j + 1, :, :] + while True: + roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item() + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2 + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2 + + if noise.shape[2] <= 8: + break + noise = F.avg_pool2d(noise, kernel_size=2) + return reg_loss + + def kl_divergence(self, hidden_states): + mean = hidden_states.mean() + var = hidden_states.var() + return var + mean**2 - 1 - torch.log(var + 1e-7) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + source_embeds: torch.Tensor = None, + target_embeds: torch.Tensor = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + cross_attention_guidance_amount: float = 0.1, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int | None = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + source_embeds (`torch.Tensor`): + Source concept embeddings. Generation of the embeddings as per the [original + paper](https://huggingface.co/papers/2302.03027). Used in discovering the edit direction. + target_embeds (`torch.Tensor`): + Target concept embeddings. Generation of the embeddings as per the [original + paper](https://huggingface.co/papers/2302.03027). Used in discovering the edit direction. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + cross_attention_guidance_amount (`float`, defaults to 0.1): + Amount of guidance needed from the reference cross-attention maps. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Define the spatial resolutions. + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + source_embeds, + target_embeds, + callback_steps, + prompt_embeds, + ) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Generate the inverted noise from the input image or any other image + # generated from the input prompt. + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + latents_init = latents.clone() + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Rejig the UNet so that we can obtain the cross-attenion maps and + # use them for guiding the subsequent image generation. + self.unet = prepare_unet(self.unet) + + # 7. Denoising loop where we obtain the cross-attention maps. + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs={"timestep": t}, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # 8. Compute the edit directions. + edit_direction = self.construct_direction(source_embeds, target_embeds).to(prompt_embeds.device) + + # 9. Edit the prompt embeddings as per the edit directions discovered. + prompt_embeds_edit = prompt_embeds.clone() + prompt_embeds_edit[1:2] += edit_direction + + # 10. Second denoising loop to generate the edited image. + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + latents = latents_init + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # we want to learn the latent such that it steers the generation + # process towards the edited direction, so make the make initial + # noise learnable + x_in = latent_model_input.detach().clone() + x_in.requires_grad = True + + # optimizer + opt = torch.optim.SGD([x_in], lr=cross_attention_guidance_amount) + + with torch.enable_grad(): + # initialize loss + loss = Pix2PixZeroL2Loss() + + # predict the noise residual + noise_pred = self.unet( + x_in, + t, + encoder_hidden_states=prompt_embeds_edit.detach(), + cross_attention_kwargs={"timestep": t, "loss": loss}, + ).sample + + loss.loss.backward(retain_graph=False) + opt.step() + + # recompute the noise + noise_pred = self.unet( + x_in.detach(), + t, + encoder_hidden_states=prompt_embeds_edit, + cross_attention_kwargs={"timestep": None}, + ).sample + + latents = x_in.detach().chunk(2)[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_INVERT_DOC_STRING) + def invert( + self, + prompt: str | None = None, + image: PipelineImageInput = None, + num_inference_steps: int = 50, + guidance_scale: float = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + cross_attention_guidance_amount: float = 0.1, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int | None = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + lambda_auto_corr: float = 20.0, + lambda_kl: float = 20.0, + num_reg_steps: int = 5, + num_auto_corr_rolls: int = 5, + ): + r""" + Function used to generate inverted latents given a prompt and image. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.Tensor` `np.ndarray`, `PIL.Image.Image`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, or tensor representing an image batch which will be used for conditioning. Can also accept + image latents as `image`, if passing latents directly, it will not be encoded again. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 1): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + cross_attention_guidance_amount (`float`, defaults to 0.1): + Amount of guidance needed from the reference cross-attention maps. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + lambda_auto_corr (`float`, *optional*, defaults to 20.0): + Lambda parameter to control auto correction + lambda_kl (`float`, *optional*, defaults to 20.0): + Lambda parameter to control Kullback–Leibler divergence output + num_reg_steps (`int`, *optional*, defaults to 5): + Number of regularization loss steps + num_auto_corr_rolls (`int`, *optional*, defaults to 5): + Number of auto correction roll steps + + Examples: + + Returns: + [`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`] or + `tuple`: + [`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is the inverted + latents tensor and then second is the corresponding decoded image. + """ + # 1. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Preprocess image + image = self.image_processor.preprocess(image) + + # 4. Prepare latent variables + latents = self.prepare_image_latents(image, batch_size, self.vae.dtype, device, generator) + + # 5. Encode input prompt + num_images_per_prompt = 1 + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.inverse_scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.inverse_scheduler.timesteps + + # 6. Rejig the UNet so that we can obtain the cross-attenion maps and + # use them for guiding the subsequent image generation. + self.unet = prepare_unet(self.unet) + + # 7. Denoising loop where we obtain the cross-attention maps. + num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs={"timestep": t}, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # regularization of the noise prediction + with torch.enable_grad(): + for _ in range(num_reg_steps): + if lambda_auto_corr > 0: + for _ in range(num_auto_corr_rolls): + var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) + + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + + l_ac = self.auto_corr_loss(var_epsilon, generator=generator) + l_ac.backward() + + grad = var.grad.detach() / num_auto_corr_rolls + noise_pred = noise_pred - lambda_auto_corr * grad + + if lambda_kl > 0: + var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) + + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + + l_kld = self.kl_divergence(var_epsilon) + l_kld.backward() + + grad = var.grad.detach() + noise_pred = noise_pred - lambda_kl * grad + + noise_pred = noise_pred.detach() + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + inverted_latents = latents.detach().clone() + + # 8. Post-processing + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (inverted_latents, image) + + return Pix2PixInversionPipelineOutput(latents=inverted_latents, images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stochastic_karras_ve/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stochastic_karras_ve/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15c9a8c27f98dd7e1913bd57dfd5e8dae71172b4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stochastic_karras_ve/__init__.py @@ -0,0 +1,19 @@ +from typing import TYPE_CHECKING + +from ....utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = {"pipeline_stochastic_karras_ve": ["KarrasVePipeline"]} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_stochastic_karras_ve import KarrasVePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stochastic_karras_ve/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stochastic_karras_ve/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..025f333587518638e4d6a7ff98468413ed149f53 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stochastic_karras_ve/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e4b99a3ee979091506b8ebbd34d293b6f8d8fd0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stochastic_karras_ve/__pycache__/pipeline_stochastic_karras_ve.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py new file mode 100644 index 0000000000000000000000000000000000000000..ce2c785c8d98fc6bbdfcc25d2a401ea1b725a5eb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py @@ -0,0 +1,126 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ....models import UNet2DModel +from ....schedulers import KarrasVeScheduler +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class KarrasVePipeline(DiffusionPipeline): + r""" + Pipeline for unconditional image generation. + + Parameters: + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image. + scheduler ([`KarrasVeScheduler`]): + A scheduler to be used in combination with `unet` to denoise the encoded image. + """ + + # add type hints for linting + unet: UNet2DModel + scheduler: KarrasVeScheduler + + def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 50, + generator: torch.Generator | list[torch.Generator] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + **kwargs, + ) -> tuple | ImagePipelineOutput: + r""" + The call function to the pipeline for generation. + + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. + + Example: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + img_size = self.unet.config.sample_size + shape = (batch_size, 3, img_size, img_size) + + model = self.unet + + # sample x_0 ~ N(0, sigma_0^2 * I) + sample = randn_tensor(shape, generator=generator, device=self.device) * self.scheduler.init_noise_sigma + + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.progress_bar(self.scheduler.timesteps): + # here sigma_t == t_i from the paper + sigma = self.scheduler.schedule[t] + sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0 + + # 1. Select temporarily increased noise level sigma_hat + # 2. Add new noise to move from sample_i to sample_hat + sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator) + + # 3. Predict the noise residual given the noise magnitude `sigma_hat` + # The model inputs and output are adjusted by following eq. (213) in [1]. + model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2).sample + + # 4. Evaluate dx/dt at sigma_hat + # 5. Take Euler step from sigma to sigma_prev + step_output = self.scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat) + + if sigma_prev != 0: + # 6. Apply 2nd order correction + # The model inputs and output are adjusted by following eq. (213) in [1]. + model_output = (sigma_prev / 2) * model((step_output.prev_sample + 1) / 2, sigma_prev / 2).sample + step_output = self.scheduler.step_correct( + model_output, + sigma_hat, + sigma_prev, + sample_hat, + step_output.prev_sample, + step_output["derivative"], + ) + sample = step_output.prev_sample + + sample = (sample / 2 + 0.5).clamp(0, 1) + image = sample.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8ea6ef6e2f65b96aebebdf72cb80135003e4f08d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/__init__.py @@ -0,0 +1,71 @@ +from typing import TYPE_CHECKING + +from ....utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ....utils.dummy_torch_and_transformers_objects import ( + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, + ) + + _dummy_objects.update( + { + "VersatileDiffusionDualGuidedPipeline": VersatileDiffusionDualGuidedPipeline, + "VersatileDiffusionImageVariationPipeline": VersatileDiffusionImageVariationPipeline, + "VersatileDiffusionPipeline": VersatileDiffusionPipeline, + "VersatileDiffusionTextToImagePipeline": VersatileDiffusionTextToImagePipeline, + } + ) +else: + _import_structure["modeling_text_unet"] = ["UNetFlatConditionModel"] + _import_structure["pipeline_versatile_diffusion"] = ["VersatileDiffusionPipeline"] + _import_structure["pipeline_versatile_diffusion_dual_guided"] = ["VersatileDiffusionDualGuidedPipeline"] + _import_structure["pipeline_versatile_diffusion_image_variation"] = ["VersatileDiffusionImageVariationPipeline"] + _import_structure["pipeline_versatile_diffusion_text_to_image"] = ["VersatileDiffusionTextToImagePipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ....utils.dummy_torch_and_transformers_objects import ( + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, + ) + else: + from .pipeline_versatile_diffusion import VersatileDiffusionPipeline + from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline + from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline + from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..587b1054aee3e9e0fa0964bc3e59b699c6225d70 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/__pycache__/modeling_text_unet.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/__pycache__/modeling_text_unet.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53210ac103c03526ef048313297cac2ec26a2ce9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/__pycache__/modeling_text_unet.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/__pycache__/pipeline_versatile_diffusion.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/__pycache__/pipeline_versatile_diffusion.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fce7707daab9344ec3f89661f26e6c676cba78a9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/__pycache__/pipeline_versatile_diffusion.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_dual_guided.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_dual_guided.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..307cc21044c661f9746999c798bf6538bd74c7b6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_dual_guided.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_image_variation.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_image_variation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d55cb01e7ed318c7ff6bc26c97750693e730299 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_image_variation.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_text_to_image.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_text_to_image.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0adb04fe630634679c710780819959197ad7c5a5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/__pycache__/pipeline_versatile_diffusion_text_to_image.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..7be159d77af5ddf3392fbae4c1b808d953bdb801 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -0,0 +1,2438 @@ +from typing import Any + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.utils import deprecate + +from ....configuration_utils import ConfigMixin, register_to_config +from ....models import ModelMixin +from ....models.activations import get_activation +from ....models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttentionProcessor, + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + AttnProcessor, +) +from ....models.embeddings import ( + GaussianFourierProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from ....models.resnet import ResnetBlockCondNorm2D +from ....models.transformers.dual_transformer_2d import DualTransformer2DModel +from ....models.transformers.transformer_2d import Transformer2DModel +from ....models.unets.unet_2d_condition import UNet2DConditionOutput +from ....utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ....utils.torch_utils import apply_freeu + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + num_attention_heads, + transformer_layers_per_block, + attention_type, + attention_head_dim, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + dropout=0.0, +): + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlockFlat": + return DownBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlockFlat": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockFlat") + return CrossAttnDownBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{down_block_type} is not supported.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + num_attention_heads, + transformer_layers_per_block, + resolution_idx, + attention_type, + attention_head_dim, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + dropout=0.0, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlockFlat": + return UpBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlockFlat": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockFlat") + return CrossAttnUpBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + raise ValueError(f"{up_block_type} is not supported.") + + +class FourierEmbedder(nn.Module): + def __init__(self, num_freqs=64, temperature=100): + super().__init__() + + self.num_freqs = num_freqs + self.temperature = temperature + + freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs) + freq_bands = freq_bands[None, None, None] + self.register_buffer("freq_bands", freq_bands, persistent=False) + + def __call__(self, x): + x = self.freq_bands * x.unsqueeze(-1) + return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1) + + +class GLIGENTextBoundingboxProjection(nn.Module): + def __init__(self, positive_len, out_dim, feature_type, fourier_freqs=8): + super().__init__() + self.positive_len = positive_len + self.out_dim = out_dim + + self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs) + self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy + + if isinstance(out_dim, tuple): + out_dim = out_dim[0] + + if feature_type == "text-only": + self.linears = nn.Sequential( + nn.Linear(self.positive_len + self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) + + elif feature_type == "text-image": + self.linears_text = nn.Sequential( + nn.Linear(self.positive_len + self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + self.linears_image = nn.Sequential( + nn.Linear(self.positive_len + self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) + self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) + + self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim])) + + def forward( + self, + boxes, + masks, + positive_embeddings=None, + phrases_masks=None, + image_masks=None, + phrases_embeddings=None, + image_embeddings=None, + ): + masks = masks.unsqueeze(-1) + + xyxy_embedding = self.fourier_embedder(boxes) + xyxy_null = self.null_position_feature.view(1, 1, -1) + xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null + + if positive_embeddings: + positive_null = self.null_positive_feature.view(1, 1, -1) + positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null + + objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1)) + else: + phrases_masks = phrases_masks.unsqueeze(-1) + image_masks = image_masks.unsqueeze(-1) + + text_null = self.null_text_feature.view(1, 1, -1) + image_null = self.null_image_feature.view(1, 1, -1) + + phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null + image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null + + objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1)) + objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1)) + objs = torch.cat([objs_text, objs_image], dim=1) + + return objs + + +class UNetFlatConditionModel(ModelMixin, ConfigMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`): + Block type for middle of UNet, it can be one of `UNetMidBlockFlatCrossAttn`, `UNetMidBlockFlat`, or + `UNetMidBlockFlatSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int`, `tuple[int]`, or `tuple[tuple]` , *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`], + [`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`]. + reverse_transformer_layers_per_block : (`tuple[tuple]`, *optional*, defaults to None): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling + blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `tuple[tuple]` and for + [`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`], + [`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlockFlat`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, + *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, + *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlockFlatSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlockFlat", "CrossAttnUpBlockFlat"] + + @register_to_config + def __init__( + self, + sample_size: int | None = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: tuple[str] = ( + "CrossAttnDownBlockFlat", + "CrossAttnDownBlockFlat", + "CrossAttnDownBlockFlat", + "DownBlockFlat", + ), + mid_block_type: str = "UNetMidBlockFlatCrossAttn", + up_block_types: tuple[str] = ( + "UpBlockFlat", + "CrossAttnUpBlockFlat", + "CrossAttnUpBlockFlat", + "CrossAttnUpBlockFlat", + ), + only_cross_attention: bool | tuple[bool] = False, + block_out_channels: tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int | tuple[int] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + dropout: float = 0.0, + act_fn: str = "silu", + norm_num_groups: int | None = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int | tuple[int] = 1280, + transformer_layers_per_block: int | tuple[int] | tuple[tuple] = 1, + reverse_transformer_layers_per_block: tuple[tuple[int]] | None = None, + encoder_hid_dim: int | None = None, + encoder_hid_dim_type: str | None = None, + attention_head_dim: int | tuple[int] = 8, + num_attention_heads: int | tuple[int] | None = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: str | None = None, + addition_embed_type: str | None = None, + addition_time_embed_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: int | None = None, + time_embedding_act_fn: str | None = None, + timestep_post_act: str | None = None, + time_cond_proj_dim: int | None = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: int | None = None, + attention_type: str = "default", + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: bool | None = None, + cross_attention_norm: str | None = None, + addition_embed_type_num_heads=64, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = LinearMultiDim( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlockFlatCrossAttn": + self.mid_block = UNetMidBlockFlatCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + elif mid_block_type == "UNetMidBlockFlatSimpleCrossAttn": + self.mid_block = UNetMidBlockFlatSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + attention_head_dim=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type == "UNetMidBlockFlat": + self.mid_block = UNetMidBlockFlat( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + num_layers=0, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + add_attention=False, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = ( + list(reversed(transformer_layers_per_block)) + if reverse_transformer_layers_per_block is None + else reverse_transformer_layers_per_block + ) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resolution_idx=i, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = LinearMultiDim( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + if attention_type in ["gated", "gated-text-image"]: + positive_len = 768 + if isinstance(cross_attention_dim, int): + positive_len = cross_attention_dim + elif isinstance(cross_attention_dim, (list, tuple)): + positive_len = cross_attention_dim[0] + + feature_type = "text-only" if attention_type == "gated" else "text-image" + self.position_net = GLIGENTextBoundingboxProjection( + positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type + ) + + @property + def attn_processors(self) -> dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: AttentionProcessor | dict[str, AttentionProcessor]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: list[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def enable_freeu(self, s1, s2, b1, b2): + r"""Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497. + + The suffixes after the scaling factors represent the stage blocks where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + def disable_freeu(self): + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: + setattr(upsample_block, k, None) + + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + > [!WARNING] > This API is 🧪 experimental. + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + > [!WARNING] > This API is 🧪 experimental. + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def unload_lora(self): + """Unloads LoRA weights.""" + deprecate( + "unload_lora", + "0.28.0", + "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().", + ) + for module in self.modules(): + if hasattr(module, "set_lora_layer"): + module.set_lora_layer(None) + + def forward( + self, + sample: torch.Tensor, + timestep: torch.Tensor | float | int, + encoder_hidden_states: torch.Tensor, + class_labels: torch.Tensor | None = None, + timestep_cond: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + added_cond_kwargs: dict[str, torch.Tensor] | None = None, + down_block_additional_residuals: tuple[torch.Tensor] | None = None, + mid_block_additional_residual: torch.Tensor | None = None, + down_intrablock_additional_residuals: tuple[torch.Tensor] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + return_dict: bool = True, + ) -> UNet2DConditionOutput | tuple: + r""" + The [`UNetFlatConditionModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added to UNet long skip connections from down blocks to up blocks for + example from ControlNet side model(s) + mid_block_additional_residual (`torch.Tensor`, *optional*): + additional residual to be added to UNet mid block output, for example from ControlNet side model + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) + + Returns: + [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + # Forward upsample size to force interpolation output size. + forward_upsample_size = True + break + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + is_npu = sample.device.type == "npu" + if isinstance(timestep, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds) + encoder_hidden_states = (encoder_hidden_states, image_embeds) + + # 2. pre-process + sample = self.conv_in(sample) + + # 2.5 GLIGEN position net + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + + # 3. down + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlockFlat + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape + ): + sample += down_intrablock_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + scale=lora_scale, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) + + +class LinearMultiDim(nn.Linear): + def __init__(self, in_features, out_features=None, second_dim=4, *args, **kwargs): + in_features = [in_features, second_dim, 1] if isinstance(in_features, int) else list(in_features) + if out_features is None: + out_features = in_features + out_features = [out_features, second_dim, 1] if isinstance(out_features, int) else list(out_features) + self.in_features_multidim = in_features + self.out_features_multidim = out_features + super().__init__(np.array(in_features).prod(), np.array(out_features).prod()) + + def forward(self, input_tensor, *args, **kwargs): + shape = input_tensor.shape + n_dim = len(self.in_features_multidim) + input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_features) + output_tensor = super().forward(input_tensor) + output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_features_multidim) + return output_tensor + + +class ResnetBlockFlat(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + time_embedding_norm="default", + use_in_shortcut=None, + second_dim=4, + **kwargs, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + + in_channels = [in_channels, second_dim, 1] if isinstance(in_channels, int) else list(in_channels) + self.in_channels_prod = np.array(in_channels).prod() + self.channels_multidim = in_channels + + if out_channels is not None: + out_channels = [out_channels, second_dim, 1] if isinstance(out_channels, int) else list(out_channels) + out_channels_prod = np.array(out_channels).prod() + self.out_channels_multidim = out_channels + else: + out_channels_prod = self.in_channels_prod + self.out_channels_multidim = self.channels_multidim + self.time_embedding_norm = time_embedding_norm + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=self.in_channels_prod, eps=eps, affine=True) + self.conv1 = torch.nn.Conv2d(self.in_channels_prod, out_channels_prod, kernel_size=1, padding=0) + + if temb_channels is not None: + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels_prod) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels_prod, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels_prod, out_channels_prod, kernel_size=1, padding=0) + + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = ( + self.in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut + ) + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + self.in_channels_prod, out_channels_prod, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, input_tensor, temb): + shape = input_tensor.shape + n_dim = len(self.channels_multidim) + input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_channels_prod, 1, 1) + input_tensor = input_tensor.view(-1, self.in_channels_prod, 1, 1) + + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + output_tensor = output_tensor.view(*shape[0:-n_dim], -1) + output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_channels_multidim) + + return output_tensor + + +class DownBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + LinearMultiDim( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor | None = None + ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + output_states = () + + for resnet in self.resnets: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int | tuple[int] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + LinearMultiDim( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + additional_residuals: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]: + output_states = () + + blocks = list(zip(self.resnets, self.attentions)) + + for i, (resnet, attn) in enumerate(blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +# Copied from diffusers.models.unets.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim +class UpBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: int | None = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlockFlat( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + upsample_size: int | None = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +# Copied from diffusers.models.unets.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim +class CrossAttnUpBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: int | None = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int | tuple[int] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlockFlat( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: tuple[torch.Tensor, ...], + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + upsample_size: int | None = None, + attention_mask: torch.Tensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat +class UNetMidBlockFlat(nn.Module): + """ + A 2D UNet mid-block [`UNetMidBlockFlat`] with multiple residual blocks and optional attention blocks. + + Args: + in_channels (`int`): The number of input channels. + temb_channels (`int`): The number of temporal embedding channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_time_scale_shift (`str`, *optional*, defaults to `default`): + The type of normalization to apply to the time embeddings. This can help to improve the performance of the + model on tasks with long-range temporal dependencies. + resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + attn_groups (`int | None`, *optional*, defaults to None): The number of groups for the attention blocks. + resnet_pre_norm (`bool`, *optional*, defaults to `True`): + Whether to use pre-normalization for the resnet blocks. + add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks. + attention_head_dim (`int`, *optional*, defaults to 1): + Dimension of a single attention head. The number of attention heads is determined based on this value and + the number of input channels. + output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor. + + Returns: + `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels, + height, width)`. + + """ + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + attn_groups: int | None = None, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + if attn_groups is None: + attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None + + # there is always at least one resnet + if resnet_time_scale_shift == "spatial": + resnets = [ + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ] + else: + resnets = [ + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}." + ) + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=attn_groups, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor | None = None) -> torch.Tensor: + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if torch.is_grad_enabled() and self.gradient_checkpointing: + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat +class UNetMidBlockFlatCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + out_channels: int | None = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int | tuple[int] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_groups_out: int | None = None, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + + out_channels = out_channels or in_channels + self.in_channels = in_channels + self.out_channels = out_channels + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + resnet_groups_out = resnet_groups_out or resnet_groups + + # there is always at least one resnet + resnets = [ + ResnetBlockFlat( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + groups_out=resnet_groups_out, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for i in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups_out, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlockFlat( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups_out, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat +class UNetMidBlockFlatSimpleCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: str | None = None, + ): + super().__init__() + + self.has_cross_attention = True + + self.attention_head_dim = attention_head_dim + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + self.num_heads = in_channels // self.attention_head_dim + + # there is always at least one resnet + resnets = [ + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ] + attentions = [] + + for _ in range(num_layers): + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=in_channels, + cross_attention_dim=in_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + # attn + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + # resnet + hidden_states = resnet(hidden_states, temb) + + return hidden_states diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..101a1b72e7f95e4d28b68964f89d3e7566ccfe57 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py @@ -0,0 +1,421 @@ +import inspect +from typing import Callable + +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel + +from ....models import AutoencoderKL, UNet2DConditionModel +from ....schedulers import KarrasDiffusionSchedulers +from ....utils import logging +from ...pipeline_utils import DiffusionPipeline +from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline +from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline +from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + tokenizer: CLIPTokenizer + image_feature_extractor: CLIPImageProcessor + text_encoder: CLIPTextModel + image_encoder: CLIPVisionModel + image_unet: UNet2DConditionModel + text_unet: UNet2DConditionModel + vae: AutoencoderKL + scheduler: KarrasDiffusionSchedulers + + def __init__( + self, + tokenizer: CLIPTokenizer, + image_feature_extractor: CLIPImageProcessor, + text_encoder: CLIPTextModel, + image_encoder: CLIPVisionModel, + image_unet: UNet2DConditionModel, + text_unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + image_feature_extractor=image_feature_extractor, + text_encoder=text_encoder, + image_encoder=image_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + + @torch.no_grad() + def image_variation( + self, + image: torch.Tensor | PIL.Image.Image, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image`, `list[PIL.Image.Image]` or `torch.Tensor`): + The image prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + + >>> pipe = VersatileDiffusionPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe.image_variation(image, generator=generator).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + expected_components = inspect.signature(VersatileDiffusionImageVariationPipeline.__init__).parameters.keys() + components = {name: component for name, component in self.components.items() if name in expected_components} + return VersatileDiffusionImageVariationPipeline(**components)( + image=image, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) + + @torch.no_grad() + def text_to_image( + self, + prompt: str | list[str], + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide image generation. + height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionPipeline + >>> import torch + + >>> pipe = VersatileDiffusionPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe.text_to_image("an astronaut riding on a horse on mars", generator=generator).images[0] + >>> image.save("./astronaut.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + expected_components = inspect.signature(VersatileDiffusionTextToImagePipeline.__init__).parameters.keys() + components = {name: component for name, component in self.components.items() if name in expected_components} + temp_pipeline = VersatileDiffusionTextToImagePipeline(**components) + output = temp_pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) + # swap the attention blocks back to the original state + temp_pipeline._swap_unet_attention_blocks() + + return output + + @torch.no_grad() + def dual_guided( + self, + prompt: PIL.Image.Image | list[PIL.Image.Image], + image: str | list[str], + text_to_image_strength: float = 0.5, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide image generation. + height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + >>> text = "a red car in the sun" + + >>> pipe = VersatileDiffusionPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> text_to_image_strength = 0.75 + + >>> image = pipe.dual_guided( + ... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator + ... ).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + expected_components = inspect.signature(VersatileDiffusionDualGuidedPipeline.__init__).parameters.keys() + components = {name: component for name, component in self.components.items() if name in expected_components} + temp_pipeline = VersatileDiffusionDualGuidedPipeline(**components) + output = temp_pipeline( + prompt=prompt, + image=image, + text_to_image_strength=text_to_image_strength, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) + temp_pipeline._revert_dual_attention() + + return output diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py new file mode 100644 index 0000000000000000000000000000000000000000..8f8fb712e023b34b84e380f73fd1b98c5573e8ce --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -0,0 +1,561 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import PIL.Image +import torch +import torch.utils.checkpoint +from transformers import ( + CLIPImageProcessor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ....image_processor import VaeImageProcessor +from ....models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel +from ....schedulers import KarrasDiffusionSchedulers +from ....utils import deprecate, logging +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .modeling_text_unet import UNetFlatConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): + r""" + Pipeline for image-text dual-guided generation using Versatile Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [`~transformers.BERT`]. + tokenizer ([`~transformers.BertTokenizer`]): + A `BertTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "bert->unet->vqvae" + + tokenizer: CLIPTokenizer + image_feature_extractor: CLIPImageProcessor + text_encoder: CLIPTextModelWithProjection + image_encoder: CLIPVisionModelWithProjection + image_unet: UNet2DConditionModel + text_unet: UNetFlatConditionModel + vae: AutoencoderKL + scheduler: KarrasDiffusionSchedulers + + _optional_components = ["text_unet"] + + def __init__( + self, + tokenizer: CLIPTokenizer, + image_feature_extractor: CLIPImageProcessor, + text_encoder: CLIPTextModelWithProjection, + image_encoder: CLIPVisionModelWithProjection, + image_unet: UNet2DConditionModel, + text_unet: UNetFlatConditionModel, + vae: AutoencoderKL, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + self.register_modules( + tokenizer=tokenizer, + image_feature_extractor=image_feature_extractor, + text_encoder=text_encoder, + image_encoder=image_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + if self.text_unet is not None and ( + "dual_cross_attention" not in self.image_unet.config or not self.image_unet.config.dual_cross_attention + ): + # if loading from a universal checkpoint rather than a saved dual-guided pipeline + self._convert_to_dual_attention() + + def remove_unused_weights(self): + self.register_modules(text_unet=None) + + def _convert_to_dual_attention(self): + """ + Replace image_unet's `Transformer2DModel` blocks with `DualTransformer2DModel` that contains transformer blocks + from both `image_unet` and `text_unet` + """ + for name, module in self.image_unet.named_modules(): + if isinstance(module, Transformer2DModel): + parent_name, index = name.rsplit(".", 1) + index = int(index) + + image_transformer = self.image_unet.get_submodule(parent_name)[index] + text_transformer = self.text_unet.get_submodule(parent_name)[index] + + config = image_transformer.config + dual_transformer = DualTransformer2DModel( + num_attention_heads=config.num_attention_heads, + attention_head_dim=config.attention_head_dim, + in_channels=config.in_channels, + num_layers=config.num_layers, + dropout=config.dropout, + norm_num_groups=config.norm_num_groups, + cross_attention_dim=config.cross_attention_dim, + attention_bias=config.attention_bias, + sample_size=config.sample_size, + num_vector_embeds=config.num_vector_embeds, + activation_fn=config.activation_fn, + num_embeds_ada_norm=config.num_embeds_ada_norm, + ) + dual_transformer.transformers[0] = image_transformer + dual_transformer.transformers[1] = text_transformer + + self.image_unet.get_submodule(parent_name)[index] = dual_transformer + self.image_unet.register_to_config(dual_cross_attention=True) + + def _revert_dual_attention(self): + """ + Revert the image_unet `DualTransformer2DModel` blocks back to `Transformer2DModel` with image_unet weights Call + this function if you reuse `image_unet` in another pipeline, e.g. `VersatileDiffusionPipeline` + """ + for name, module in self.image_unet.named_modules(): + if isinstance(module, DualTransformer2DModel): + parent_name, index = name.rsplit(".", 1) + index = int(index) + self.image_unet.get_submodule(parent_name)[index] = module.transformers[0] + + self.image_unet.register_to_config(dual_cross_attention=False) + + def _encode_text_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + """ + + def normalize_embeddings(encoder_output): + embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state) + embeds_pooled = encoder_output.text_embeds + embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = normalize_embeddings(prompt_embeds) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def _encode_image_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + """ + + def normalize_embeddings(encoder_output): + embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state) + embeds = self.image_encoder.visual_projection(embeds) + embeds_pooled = embeds[:, 0:1] + embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + image_input = self.image_feature_extractor(images=prompt, return_tensors="pt") + pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype) + image_embeddings = self.image_encoder(pixel_values) + image_embeddings = normalize_embeddings(image_embeddings) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size + uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") + pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype) + negative_prompt_embeds = self.image_encoder(pixel_values) + negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and conditional embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) + + return image_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, image, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, PIL.Image.Image) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` `PIL.Image` or `list` but is {type(prompt)}") + if not isinstance(image, str) and not isinstance(image, PIL.Image.Image) and not isinstance(image, list): + raise ValueError(f"`image` has to be of type `str` `PIL.Image` or `list` but is {type(image)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def set_transformer_params(self, mix_ratio: float = 0.5, condition_types: tuple = ("text", "image")): + for name, module in self.image_unet.named_modules(): + if isinstance(module, DualTransformer2DModel): + module.mix_ratio = mix_ratio + + for i, type in enumerate(condition_types): + if type == "text": + module.condition_lengths[i] = self.text_encoder.config.max_position_embeddings + module.transformer_index_for_condition[i] = 1 # use the second (text) transformer + else: + module.condition_lengths[i] = 257 + module.transformer_index_for_condition[i] = 0 # use the first (image) transformer + + @torch.no_grad() + def __call__( + self, + prompt: PIL.Image.Image | list[PIL.Image.Image], + image: str | list[str], + text_to_image_strength: float = 0.5, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide image generation. + height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionDualGuidedPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + >>> text = "a red car in the sun" + + >>> pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe.remove_unused_weights() + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> text_to_image_strength = 0.75 + + >>> image = pipe( + ... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator + ... ).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.image_unet.config.sample_size * self.vae_scale_factor + width = width or self.image_unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, image, height, width, callback_steps) + + # 2. Define call parameters + prompt = [prompt] if not isinstance(prompt, list) else prompt + image = [image] if not isinstance(image, list) else image + batch_size = len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompts + prompt_embeds = self._encode_text_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance) + image_embeddings = self._encode_image_prompt(image, device, num_images_per_prompt, do_classifier_free_guidance) + dual_prompt_embeddings = torch.cat([prompt_embeds, image_embeddings], dim=1) + prompt_types = ("text", "image") + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.image_unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + dual_prompt_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Combine the attention blocks of the image and text UNets + self.set_transformer_params(text_to_image_strength, prompt_types) + + # 8. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=dual_prompt_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py new file mode 100644 index 0000000000000000000000000000000000000000..348417ad11dfe703592c7452d2e10a4575d625a4 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -0,0 +1,401 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from ....image_processor import VaeImageProcessor +from ....models import AutoencoderKL, UNet2DConditionModel +from ....schedulers import KarrasDiffusionSchedulers +from ....utils import deprecate, logging +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): + r""" + Pipeline for image variation using Versatile Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [`~transformers.BERT`]. + tokenizer ([`~transformers.BertTokenizer`]): + A `BertTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "bert->unet->vqvae" + + image_feature_extractor: CLIPImageProcessor + image_encoder: CLIPVisionModelWithProjection + image_unet: UNet2DConditionModel + vae: AutoencoderKL + scheduler: KarrasDiffusionSchedulers + + def __init__( + self, + image_feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection, + image_unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + self.register_modules( + image_feature_extractor=image_feature_extractor, + image_encoder=image_encoder, + image_unet=image_unet, + vae=vae, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + + def normalize_embeddings(encoder_output): + embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state) + embeds = self.image_encoder.visual_projection(embeds) + embeds_pooled = embeds[:, 0:1] + embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) + return embeds + + if isinstance(prompt, torch.Tensor) and len(prompt.shape) == 4: + prompt = list(prompt) + + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + image_input = self.image_feature_extractor(images=prompt, return_tensors="pt") + pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype) + image_embeddings = self.image_encoder(pixel_values) + image_embeddings = normalize_embeddings(image_embeddings) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_images: list[str] + if negative_prompt is None: + uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, PIL.Image.Image): + uncond_images = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_images = negative_prompt + + uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") + pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype) + negative_prompt_embeds = self.image_encoder(pixel_values) + negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and conditional embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) + + return image_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs + def check_inputs(self, image, height, width, callback_steps): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `list[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + image: PIL.Image.Image | list[PIL.Image.Image] | torch.Tensor, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image`, `list[PIL.Image.Image]` or `torch.Tensor`): + The image prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionImageVariationPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + + >>> pipe = VersatileDiffusionImageVariationPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe(image, generator=generator).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.image_unet.config.sample_size * self.vae_scale_factor + width = width or self.image_unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(image, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(image, PIL.Image.Image) else len(image) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + image_embeddings = self._encode_prompt( + image, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.image_unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py new file mode 100644 index 0000000000000000000000000000000000000000..8e2af1063421dfac5e911835fa8b0c0862fbe5ff --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -0,0 +1,479 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import torch +from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer + +from ....image_processor import VaeImageProcessor +from ....models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel +from ....schedulers import KarrasDiffusionSchedulers +from ....utils import deprecate, logging +from ....utils.torch_utils import randn_tensor +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .modeling_text_unet import UNetFlatConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Versatile Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [`~transformers.BERT`]. + tokenizer ([`~transformers.BertTokenizer`]): + A `BertTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "bert->unet->vqvae" + + tokenizer: CLIPTokenizer + image_feature_extractor: CLIPImageProcessor + text_encoder: CLIPTextModelWithProjection + image_unet: UNet2DConditionModel + text_unet: UNetFlatConditionModel + vae: AutoencoderKL + scheduler: KarrasDiffusionSchedulers + + _optional_components = ["text_unet"] + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + image_unet: UNet2DConditionModel, + text_unet: UNetFlatConditionModel, + vae: AutoencoderKL, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + if self.text_unet is not None: + self._swap_unet_attention_blocks() + + def _swap_unet_attention_blocks(self): + """ + Swap the `Transformer2DModel` blocks between the image and text UNets + """ + for name, module in self.image_unet.named_modules(): + if isinstance(module, Transformer2DModel): + parent_name, index = name.rsplit(".", 1) + index = int(index) + self.image_unet.get_submodule(parent_name)[index], self.text_unet.get_submodule(parent_name)[index] = ( + self.text_unet.get_submodule(parent_name)[index], + self.image_unet.get_submodule(parent_name)[index], + ) + + def remove_unused_weights(self): + self.register_modules(text_unet=None) + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + + def normalize_embeddings(encoder_output): + embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state) + embeds_pooled = encoder_output.text_embeds + embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = normalize_embeddings(prompt_embeds) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str], + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide image generation. + height (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.image_unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionTextToImagePipeline + >>> import torch + + >>> pipe = VersatileDiffusionTextToImagePipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe.remove_unused_weights() + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe("an astronaut riding on a horse on mars", generator=generator).images[0] + >>> image.save("./astronaut.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.image_unet.config.sample_size * self.vae_scale_factor + width = width or self.image_unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.image_unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/vq_diffusion/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/vq_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..070903377c7188415af0417d4839d74a8a34dc01 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/vq_diffusion/__init__.py @@ -0,0 +1,57 @@ +from typing import TYPE_CHECKING + +from ....utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ....utils.dummy_torch_and_transformers_objects import ( + LearnedClassifierFreeSamplingEmbeddings, + VQDiffusionPipeline, + ) + + _dummy_objects.update( + { + "LearnedClassifierFreeSamplingEmbeddings": LearnedClassifierFreeSamplingEmbeddings, + "VQDiffusionPipeline": VQDiffusionPipeline, + } + ) +else: + _import_structure["pipeline_vq_diffusion"] = ["LearnedClassifierFreeSamplingEmbeddings", "VQDiffusionPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ....utils.dummy_torch_and_transformers_objects import ( + LearnedClassifierFreeSamplingEmbeddings, + VQDiffusionPipeline, + ) + else: + from .pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings, VQDiffusionPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/vq_diffusion/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/vq_diffusion/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae3c631ba821f521a7c4825fdfb831d12c4bacb5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/vq_diffusion/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/vq_diffusion/__pycache__/pipeline_vq_diffusion.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/vq_diffusion/__pycache__/pipeline_vq_diffusion.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd6dcb63f5b59d39600779a4cf438ad0eeafd03a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/vq_diffusion/__pycache__/pipeline_vq_diffusion.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..8f3b9512e888fce4ba9b639987f74da7a63bb755 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py @@ -0,0 +1,325 @@ +# Copyright 2025 Microsoft and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable + +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ....configuration_utils import ConfigMixin, register_to_config +from ....models import ModelMixin, Transformer2DModel, VQModel +from ....schedulers import VQDiffusionScheduler +from ....utils import logging +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin): + """ + Utility class for storing learned text embeddings for classifier free sampling + """ + + @register_to_config + def __init__(self, learnable: bool, hidden_size: int | None = None, length: int | None = None): + super().__init__() + + self.learnable = learnable + + if self.learnable: + assert hidden_size is not None, "learnable=True requires `hidden_size` to be set" + assert length is not None, "learnable=True requires `length` to be set" + + embeddings = torch.zeros(length, hidden_size) + else: + embeddings = None + + self.embeddings = torch.nn.Parameter(embeddings) + + +class VQDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using VQ Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vqvae ([`VQModel`]): + Vector Quantized Variational Auto-Encoder (VAE) model to encode and decode images to and from latent + representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + transformer ([`Transformer2DModel`]): + A conditional `Transformer2DModel` to denoise the encoded image latents. + scheduler ([`VQDiffusionScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + vqvae: VQModel + text_encoder: CLIPTextModel + tokenizer: CLIPTokenizer + transformer: Transformer2DModel + learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings + scheduler: VQDiffusionScheduler + + def __init__( + self, + vqvae: VQModel, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + transformer: Transformer2DModel, + scheduler: VQDiffusionScheduler, + learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings, + ): + super().__init__() + + self.register_modules( + vqvae=vqvae, + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings, + ) + + def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + prompt_embeds = self.text_encoder(text_input_ids.to(self.device))[0] + + # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion. + # While CLIP does normalize the pooled output of the text transformer when combining + # the image and text embeddings, CLIP does not directly normalize the last hidden state. + # + # CLIP normalizing the pooled output. + # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053 + prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True) + + # duplicate text embeddings for each generation per prompt + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + if self.learned_classifier_free_sampling_embeddings.learnable: + negative_prompt_embeds = self.learned_classifier_free_sampling_embeddings.embeddings + negative_prompt_embeds = negative_prompt_embeds.unsqueeze(0).repeat(batch_size, 1, 1) + else: + uncond_tokens = [""] * batch_size + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + # See comment for normalizing text embeddings + negative_prompt_embeds = negative_prompt_embeds / negative_prompt_embeds.norm(dim=-1, keepdim=True) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str], + num_inference_steps: int = 100, + guidance_scale: float = 5.0, + truncation_rate: float = 1.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + ) -> ImagePipelineOutput | tuple: + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide image generation. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)): + Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at + most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above + `truncation_rate` are set to zero. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor` of shape (batch), *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Must be valid embedding indices.If not provided, a latents tensor will be generated of + completely masked latent pixels. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # get the initial completely masked latents unless the user supplied it + + latents_shape = (batch_size, self.transformer.num_latent_pixels) + if latents is None: + mask_class = self.transformer.num_vector_embeds - 1 + latents = torch.full(latents_shape, mask_class).to(self.device) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + if (latents < 0).any() or (latents >= self.transformer.num_vector_embeds).any(): + raise ValueError( + "Unexpected latents value(s). All latents be valid embedding indices i.e. in the range 0," + f" {self.transformer.num_vector_embeds - 1} (inclusive)." + ) + latents = latents.to(self.device) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + + timesteps_tensor = self.scheduler.timesteps.to(self.device) + + sample = latents + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the sample if we are doing classifier free guidance + latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample + + # predict the un-noised image + # model_output == `log_p_x_0` + model_output = self.transformer(latent_model_input, encoder_hidden_states=prompt_embeds, timestep=t).sample + + if do_classifier_free_guidance: + model_output_uncond, model_output_text = model_output.chunk(2) + model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond) + model_output -= torch.logsumexp(model_output, dim=1, keepdim=True) + + model_output = self.truncate(model_output, truncation_rate) + + # remove `log(0)`'s (`-inf`s) + model_output = model_output.clamp(-70) + + # compute the previous noisy sample x_t -> x_t-1 + sample = self.scheduler.step(model_output, timestep=t, sample=sample, generator=generator).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, sample) + + embedding_channels = self.vqvae.config.vq_embed_dim + embeddings_shape = (batch_size, self.transformer.height, self.transformer.width, embedding_channels) + embeddings = self.vqvae.quantize.get_codebook_entry(sample, shape=embeddings_shape) + image = self.vqvae.decode(embeddings, force_not_quantize=True).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + + def truncate(self, log_p_x_0: torch.Tensor, truncation_rate: float) -> torch.Tensor: + """ + Truncates `log_p_x_0` such that for each column vector, the total cumulative probability is `truncation_rate` + The lowest probabilities that would increase the cumulative probability above `truncation_rate` are set to + zero. + """ + sorted_log_p_x_0, indices = torch.sort(log_p_x_0, 1, descending=True) + sorted_p_x_0 = torch.exp(sorted_log_p_x_0) + keep_mask = sorted_p_x_0.cumsum(dim=1) < truncation_rate + + # Ensure that at least the largest probability is not zeroed out + all_true = torch.full_like(keep_mask[:, 0:1, :], True) + keep_mask = torch.cat((all_true, keep_mask), dim=1) + keep_mask = keep_mask[:, :-1, :] + + keep_mask = keep_mask.gather(1, indices.argsort(1)) + + rv = log_p_x_0.clone() + + rv[~keep_mask] = -torch.inf # -inf = log(0) + + return rv diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/dit/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/dit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fe2a94f3cba77d867f97111a41895918842adc27 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/dit/__init__.py @@ -0,0 +1,19 @@ +from typing import TYPE_CHECKING + +from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule + + +_import_structure = {"pipeline_dit": ["DiTPipeline"]} + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + from .pipeline_dit import DiTPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/dit/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/dit/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81fde2a1907b012295081144e03cde3c6b0cae17 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/dit/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/dit/__pycache__/pipeline_dit.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/dit/__pycache__/pipeline_dit.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..190dfe10a5ed3dc8d84ab47e1fc6891186454c34 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/dit/__pycache__/pipeline_dit.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/dit/pipeline_dit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/dit/pipeline_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..505360fc78a71a0569bff01b6bc84bdb3ccf7011 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/dit/pipeline_dit.py @@ -0,0 +1,249 @@ +# Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) +# William Peebles and Saining Xie +# +# Copyright (c) 2021 OpenAI +# MIT License +# +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ...models import AutoencoderKL, DiTTransformer2DModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import is_torch_xla_available +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +class DiTPipeline(DiffusionPipeline): + r""" + Pipeline for image generation based on a Transformer backbone instead of a UNet. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + transformer ([`DiTTransformer2DModel`]): + A class conditioned `DiTTransformer2DModel` to denoise the encoded image latents. Initially published as + [`Transformer2DModel`](https://huggingface.co/facebook/DiT-XL-2-256/blob/main/transformer/config.json#L2) + in the config, but the mismatch can be ignored. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "transformer->vae" + + def __init__( + self, + transformer: DiTTransformer2DModel, + vae: AutoencoderKL, + scheduler: KarrasDiffusionSchedulers, + id2label: dict[int, str] | None = None, + ): + super().__init__() + self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler) + + # create a imagenet -> id dictionary for easier use + self.labels = {} + if id2label is not None: + for key, value in id2label.items(): + for label in value.split(","): + self.labels[label.lstrip().rstrip()] = int(key) + self.labels = dict(sorted(self.labels.items())) + + def get_label_ids(self, label: str | list[str]) -> list[int]: + r""" + + Map label strings from ImageNet to corresponding class ids. + + Parameters: + label (`str` or `dict` of `str`): + Label strings to be mapped to class ids. + + Returns: + `list` of `int`: + Class ids to be processed by pipeline. + """ + + if not isinstance(label, list): + label = list(label) + + for l in label: + if l not in self.labels: + raise ValueError( + f"{l} does not exist. Please make sure to select one of the following labels: \n {self.labels}." + ) + + return [self.labels[l] for l in label] + + @torch.no_grad() + def __call__( + self, + class_labels: list[int], + guidance_scale: float = 4.0, + generator: torch.Generator | list[torch.Generator] | None = None, + num_inference_steps: int = 50, + output_type: str | None = "pil", + return_dict: bool = True, + ) -> ImagePipelineOutput | tuple: + r""" + The call function to the pipeline for generation. + + Args: + class_labels (list[int]): + list of ImageNet class labels for the images to be generated. + guidance_scale (`float`, *optional*, defaults to 4.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + num_inference_steps (`int`, *optional*, defaults to 250): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + ```py + >>> from diffusers import DiTPipeline, DPMSolverMultistepScheduler + >>> import torch + + >>> pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16) + >>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + >>> pipe = pipe.to("cuda") + + >>> # pick words from Imagenet class labels + >>> pipe.labels # to print all available words + + >>> # pick words that exist in ImageNet + >>> words = ["white shark", "umbrella"] + + >>> class_ids = pipe.get_label_ids(words) + + >>> generator = torch.manual_seed(33) + >>> output = pipe(class_labels=class_ids, num_inference_steps=25, generator=generator) + + >>> image = output.images[0] # label 'white shark' + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + + batch_size = len(class_labels) + latent_size = self.transformer.config.sample_size + latent_channels = self.transformer.config.in_channels + + latents = randn_tensor( + shape=(batch_size, latent_channels, latent_size, latent_size), + generator=generator, + device=self._execution_device, + dtype=self.transformer.dtype, + ) + latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents + + class_labels = torch.tensor(class_labels, device=self._execution_device).reshape(-1) + class_null = torch.tensor([1000] * batch_size, device=self._execution_device) + class_labels_input = torch.cat([class_labels, class_null], 0) if guidance_scale > 1 else class_labels + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + for t in self.progress_bar(self.scheduler.timesteps): + if guidance_scale > 1: + half = latent_model_input[: len(latent_model_input) // 2] + latent_model_input = torch.cat([half, half], dim=0) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + timesteps = t + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + is_npu = latent_model_input.device.type == "npu" + if isinstance(timesteps, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(latent_model_input.shape[0]) + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, timestep=timesteps, class_labels=class_labels_input + ).sample + + # perform guidance + if guidance_scale > 1: + eps, rest = noise_pred[:, :latent_channels], noise_pred[:, latent_channels:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + + half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + + noise_pred = torch.cat([eps, rest], dim=1) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + model_output, _ = torch.split(noise_pred, latent_channels, dim=1) + else: + model_output = noise_pred + + # compute previous image: x_t -> x_t-1 + latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample + + if XLA_AVAILABLE: + xm.mark_step() + + if guidance_scale > 1: + latents, _ = latent_model_input.chunk(2, dim=0) + else: + latents = latent_model_input + + latents = 1 / self.vae.config.scaling_factor * latents + samples = self.vae.decode(latents).sample + + samples = (samples / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + samples = samples.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + samples = self.numpy_to_pil(samples) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (samples,) + + return ImagePipelineOutput(images=samples) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..49923423f951c1e6124ddb9b0506f740f28a7a5a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/__init__.py @@ -0,0 +1,52 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_easyanimate"] = ["EasyAnimatePipeline"] + _import_structure["pipeline_easyanimate_control"] = ["EasyAnimateControlPipeline"] + _import_structure["pipeline_easyanimate_inpaint"] = ["EasyAnimateInpaintPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_easyanimate import EasyAnimatePipeline + from .pipeline_easyanimate_control import EasyAnimateControlPipeline + from .pipeline_easyanimate_inpaint import EasyAnimateInpaintPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84fe80a70dc5f29e481816bf5214d836ee89912d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/__pycache__/pipeline_easyanimate.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/__pycache__/pipeline_easyanimate.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83ad447c5333014801307cdab9a5f0f5dee13ef9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/__pycache__/pipeline_easyanimate.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/__pycache__/pipeline_easyanimate_control.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/__pycache__/pipeline_easyanimate_control.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..855d4159023ee7c665dd84e5263dba543f34b0ab Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/__pycache__/pipeline_easyanimate_control.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/__pycache__/pipeline_easyanimate_inpaint.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/__pycache__/pipeline_easyanimate_inpaint.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09c2c28b0b081470655d371896e64ec581bce9ad Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/__pycache__/pipeline_easyanimate_inpaint.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/__pycache__/pipeline_output.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/__pycache__/pipeline_output.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d74a03a7a89f0190e6d7563e77c0742a13352b2e Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/__pycache__/pipeline_output.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/pipeline_easyanimate.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/pipeline_easyanimate.py new file mode 100644 index 0000000000000000000000000000000000000000..6ec8f44e6d1af6f0f541f751ec06152c7ae8505a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/pipeline_easyanimate.py @@ -0,0 +1,774 @@ +# Copyright 2025 The EasyAnimate team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import torch +from transformers import ( + BertModel, + BertTokenizer, + Qwen2Tokenizer, + Qwen2VLForConditionalGeneration, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import EasyAnimatePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import EasyAnimatePipeline + >>> from diffusers.utils import export_to_video + + >>> # Models: "alibaba-pai/EasyAnimateV5.1-12b-zh" + >>> pipe = EasyAnimatePipeline.from_pretrained( + ... "alibaba-pai/EasyAnimateV5.1-7b-zh-diffusers", torch_dtype=torch.float16 + ... ).to("cuda") + >>> prompt = ( + ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + ... "atmosphere of this unique musical performance." + ... ) + >>> sample_size = (512, 512) + >>> video = pipe( + ... prompt=prompt, + ... guidance_scale=6, + ... negative_prompt="bad detailed", + ... height=sample_size[0], + ... width=sample_size[1], + ... num_inference_steps=50, + ... ).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class EasyAnimatePipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using EasyAnimate. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1. + + Args: + vae ([`AutoencoderKLMagvit`]): + Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations. + text_encoder (`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel` | None): + EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1. + tokenizer (`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer` | None): + A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text. + transformer ([`EasyAnimateTransformer3DModel`]): + The EasyAnimate model designed by EasyAnimate Team. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKLMagvit, + text_encoder: Qwen2VLForConditionalGeneration | BertModel, + tokenizer: Qwen2Tokenizer | BertTokenizer, + transformer: EasyAnimateTransformer3DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.enable_text_attention_mask = ( + self.transformer.config.enable_text_attention_mask + if getattr(self, "transformer", None) is not None + else True + ) + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 4 + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + + def encode_prompt( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 256, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + dtype (`torch.dtype`): + torch dtype + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + """ + dtype = dtype or self.text_encoder.dtype + device = device or self.text_encoder.device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + if isinstance(prompt, str): + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": prompt}], + } + ] + else: + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": _prompt}], + } + for _prompt in prompt + ] + text = [ + self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages + ] + + text_inputs = self.tokenizer( + text=text, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_attention_mask=True, + padding_side="right", + return_tensors="pt", + ) + text_inputs = text_inputs.to(self.text_encoder.device) + + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + if self.enable_text_attention_mask: + # Inference: Generation of the output + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ).hidden_states[-2] + else: + raise ValueError("LLM needs attention_mask") + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.to(device=device) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + if negative_prompt is not None and isinstance(negative_prompt, str): + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": negative_prompt}], + } + ] + else: + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": _negative_prompt}], + } + for _negative_prompt in negative_prompt + ] + text = [ + self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages + ] + + text_inputs = self.tokenizer( + text=text, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_attention_mask=True, + padding_side="right", + return_tensors="pt", + ) + text_inputs = text_inputs.to(self.text_encoder.device) + + text_input_ids = text_inputs.input_ids + negative_prompt_attention_mask = text_inputs.attention_mask + if self.enable_text_attention_mask: + # Inference: Generation of the output + negative_prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=negative_prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-2] + else: + raise ValueError("LLM needs attention_mask") + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device) + + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae_temporal_compression_ratio + 1, + height // self.vae_spatial_compression_ratio, + width // self.vae_spatial_compression_ratio, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + num_frames: int | None = 49, + height: int | None = 512, + width: int | None = 512, + num_inference_steps: int | None = 50, + guidance_scale: float | None = 5.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float | None = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + timesteps: list[int] | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + guidance_rescale: float = 0.0, + ): + r""" + Generates images or video using the EasyAnimate pipeline based on the provided prompts. + + Examples: + prompt (`str` or `list[str]`, *optional*): + Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead. + num_frames (`int`, *optional*): + Length of the generated video (in frames). + height (`int`, *optional*): + Height of the generated image in pixels. + width (`int`, *optional*): + Width of the generated image in pixels. + num_inference_steps (`int`, *optional*, defaults to 50): + Number of denoising steps during generation. More steps generally yield higher quality images but slow + down inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + Encourages the model to align outputs with prompts. A higher value may decrease image quality. + negative_prompt (`str` or `list[str]`, *optional*): + Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images to generate for each prompt. + eta (`float`, *optional*, defaults to 0.0): + Applies to DDIM scheduling. Controlled by the eta parameter from the related literature. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A generator to ensure reproducibility in image generation. + latents (`torch.Tensor`, *optional*): + Predefined latent tensors to condition generation. + prompt_embeds (`torch.Tensor`, *optional*): + Text embeddings for the prompts. Overrides prompt string inputs for more flexibility. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Embeddings for negative prompts. Overrides string inputs if defined. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the primary prompt embeddings. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for negative prompt embeddings. + output_type (`str`, *optional*, defaults to "latent"): + Format of the generated output, either as a PIL image or as a NumPy array. + return_dict (`bool`, *optional*, defaults to `True`): + If `True`, returns a structured output. Otherwise returns a simple tuple. + callback_on_step_end (`Callable`, *optional*): + Functions called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`list[str]`, *optional*): + Tensor names to be included in callback function calls. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Adjusts noise levels based on guidance scale. + original_size (`tuple[int, int]`, *optional*, defaults to `(1024, 1024)`): + Original dimensions of the output. + target_size (`tuple[int, int]`, *optional*): + Desired output dimensions for calculations. + crops_coords_top_left (`tuple[int, int]`, *optional*, defaults to `(0, 0)`): + Coordinates for cropping. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. default height and width + height = int((height // 16) * 16) + width = int((width // 16) * 16) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = self.transformer.dtype + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, mu=1 + ) + else: + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + num_frames, + height, + width, + dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( + dtype=latent_model_input.dtype + ) + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + return_dict=False, + )[0] + + if noise_pred.size()[1] != self.vae.config.latent_channels: + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + latents = 1 / self.vae.config.scaling_factor * latents + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return EasyAnimatePipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py new file mode 100644 index 0000000000000000000000000000000000000000..5e07996a661c18313c6e4eee554c5017fbf97ca2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py @@ -0,0 +1,998 @@ +# Copyright 2025 The EasyAnimate team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image +from transformers import ( + BertModel, + BertTokenizer, + Qwen2Tokenizer, + Qwen2VLForConditionalGeneration, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import EasyAnimatePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import EasyAnimateControlPipeline + >>> from diffusers.pipelines.easyanimate.pipeline_easyanimate_control import get_video_to_video_latent + >>> from diffusers.utils import export_to_video, load_video + + >>> pipe = EasyAnimateControlPipeline.from_pretrained( + ... "alibaba-pai/EasyAnimateV5.1-12b-zh-Control-diffusers", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> control_video = load_video( + ... "https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control/blob/main/asset/pose.mp4" + ... ) + >>> prompt = ( + ... "In this sunlit outdoor garden, a beautiful woman is dressed in a knee-length, sleeveless white dress. " + ... "The hem of her dress gently sways with her graceful dance, much like a butterfly fluttering in the breeze. " + ... "Sunlight filters through the leaves, casting dappled shadows that highlight her soft features and clear eyes, " + ... "making her appear exceptionally elegant. It seems as if every movement she makes speaks of youth and vitality. " + ... "As she twirls on the grass, her dress flutters, as if the entire garden is rejoicing in her dance. " + ... "The colorful flowers around her sway in the gentle breeze, with roses, chrysanthemums, and lilies each " + ... "releasing their fragrances, creating a relaxed and joyful atmosphere." + ... ) + >>> sample_size = (672, 384) + >>> num_frames = 49 + + >>> input_video, _, _ = get_video_to_video_latent(control_video, num_frames, sample_size) + >>> video = pipe( + ... prompt, + ... num_frames=num_frames, + ... negative_prompt="Twisted body, limb deformities, text subtitles, comics, stillness, ugliness, errors, garbled text.", + ... height=sample_size[0], + ... width=sample_size[1], + ... control_video=input_video, + ... ).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + + +def preprocess_image(image, sample_size): + """ + Preprocess a single image (PIL.Image, numpy.ndarray, or torch.Tensor) to a resized tensor. + """ + if isinstance(image, torch.Tensor): + # If input is a tensor, assume it's in CHW format and resize using interpolation + image = torch.nn.functional.interpolate( + image.unsqueeze(0), size=sample_size, mode="bilinear", align_corners=False + ).squeeze(0) + elif isinstance(image, Image.Image): + # If input is a PIL image, resize and convert to numpy array + image = image.resize((sample_size[1], sample_size[0])) + image = np.array(image) + elif isinstance(image, np.ndarray): + # If input is a numpy array, resize using PIL + image = Image.fromarray(image).resize((sample_size[1], sample_size[0])) + image = np.array(image) + else: + raise ValueError("Unsupported input type. Expected PIL.Image, numpy.ndarray, or torch.Tensor.") + + # Convert to tensor if not already + if not isinstance(image, torch.Tensor): + image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0 # HWC -> CHW, normalize to [0, 1] + + return image + + +def get_video_to_video_latent(input_video, num_frames, sample_size, validation_video_mask=None, ref_image=None): + if input_video is not None: + # Convert each frame in the list to tensor + input_video = [preprocess_image(frame, sample_size=sample_size) for frame in input_video] + + # Stack all frames into a single tensor (F, C, H, W) + input_video = torch.stack(input_video)[:num_frames] + + # Add batch dimension (B, F, C, H, W) + input_video = input_video.permute(1, 0, 2, 3).unsqueeze(0) + + if validation_video_mask is not None: + # Handle mask input + validation_video_mask = preprocess_image(validation_video_mask, size=sample_size) + input_video_mask = torch.where(validation_video_mask < 240 / 255.0, 0.0, 255) + + # Adjust mask dimensions to match video + input_video_mask = input_video_mask.unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0) + input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1]) + input_video_mask = input_video_mask.to(input_video.device, input_video.dtype) + else: + input_video_mask = torch.zeros_like(input_video[:, :1]) + input_video_mask[:, :, :] = 255 + else: + input_video, input_video_mask = None, None + + if ref_image is not None: + # Convert reference image to tensor + ref_image = preprocess_image(ref_image, size=sample_size) + ref_image = ref_image.permute(1, 0, 2, 3).unsqueeze(0) # Add batch dimension (B, C, H, W) + else: + ref_image = None + + return input_video, input_video_mask, ref_image + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Resize mask information in magvit +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], size=target_size, mode="trilinear", align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], size=target_size, mode="trilinear", align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate(mask, size=target_size, mode="trilinear", align_corners=False) + return resized_mask + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class EasyAnimateControlPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using EasyAnimate. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1. + + Args: + vae ([`AutoencoderKLMagvit`]): + Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations. + text_encoder (`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel` | None): + EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1. + tokenizer (`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer` | None): + A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text. + transformer ([`EasyAnimateTransformer3DModel`]): + The EasyAnimate model designed by EasyAnimate Team. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKLMagvit, + text_encoder: Qwen2VLForConditionalGeneration | BertModel, + tokenizer: Qwen2Tokenizer | BertTokenizer, + transformer: EasyAnimateTransformer3DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.enable_text_attention_mask = ( + self.transformer.config.enable_text_attention_mask + if getattr(self, "transformer", None) is not None + else True + ) + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 4 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_spatial_compression_ratio, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + + # Copied from diffusers.pipelines.easyanimate.pipeline_easyanimate.EasyAnimatePipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 256, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + dtype (`torch.dtype`): + torch dtype + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + """ + dtype = dtype or self.text_encoder.dtype + device = device or self.text_encoder.device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + if isinstance(prompt, str): + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": prompt}], + } + ] + else: + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": _prompt}], + } + for _prompt in prompt + ] + text = [ + self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages + ] + + text_inputs = self.tokenizer( + text=text, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_attention_mask=True, + padding_side="right", + return_tensors="pt", + ) + text_inputs = text_inputs.to(self.text_encoder.device) + + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + if self.enable_text_attention_mask: + # Inference: Generation of the output + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ).hidden_states[-2] + else: + raise ValueError("LLM needs attention_mask") + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.to(device=device) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + if negative_prompt is not None and isinstance(negative_prompt, str): + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": negative_prompt}], + } + ] + else: + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": _negative_prompt}], + } + for _negative_prompt in negative_prompt + ] + text = [ + self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages + ] + + text_inputs = self.tokenizer( + text=text, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_attention_mask=True, + padding_side="right", + return_tensors="pt", + ) + text_inputs = text_inputs.to(self.text_encoder.device) + + text_input_ids = text_inputs.input_ids + negative_prompt_attention_mask = text_inputs.attention_mask + if self.enable_text_attention_mask: + # Inference: Generation of the output + negative_prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=negative_prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-2] + else: + raise ValueError("LLM needs attention_mask") + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device) + + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae_temporal_compression_ratio + 1, + height // self.vae_spatial_compression_ratio, + width // self.vae_spatial_compression_ratio, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_control_latents( + self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the control to latents shape as we concatenate the control to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if control is not None: + control = control.to(device=device, dtype=dtype) + bs = 1 + new_control = [] + for i in range(0, control.shape[0], bs): + control_bs = control[i : i + bs] + control_bs = self.vae.encode(control_bs)[0] + control_bs = control_bs.mode() + new_control.append(control_bs) + control = torch.cat(new_control, dim=0) + control = control * self.vae.config.scaling_factor + + if control_image is not None: + control_image = control_image.to(device=device, dtype=dtype) + bs = 1 + new_control_pixel_values = [] + for i in range(0, control_image.shape[0], bs): + control_pixel_values_bs = control_image[i : i + bs] + control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0] + control_pixel_values_bs = control_pixel_values_bs.mode() + new_control_pixel_values.append(control_pixel_values_bs) + control_image_latents = torch.cat(new_control_pixel_values, dim=0) + control_image_latents = control_image_latents * self.vae.config.scaling_factor + else: + control_image_latents = None + + return control, control_image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + num_frames: int | None = 49, + height: int | None = 512, + width: int | None = 512, + control_video: torch.FloatTensor = None, + control_camera_video: torch.FloatTensor = None, + ref_image: torch.FloatTensor = None, + num_inference_steps: int | None = 50, + guidance_scale: float | None = 5.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float | None = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + guidance_rescale: float = 0.0, + timesteps: list[int] | None = None, + ): + r""" + Generates images or video using the EasyAnimate pipeline based on the provided prompts. + + Examples: + prompt (`str` or `list[str]`, *optional*): + Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead. + num_frames (`int`, *optional*): + Length of the generated video (in frames). + height (`int`, *optional*): + Height of the generated image in pixels. + width (`int`, *optional*): + Width of the generated image in pixels. + num_inference_steps (`int`, *optional*, defaults to 50): + Number of denoising steps during generation. More steps generally yield higher quality images but slow + down inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + Encourages the model to align outputs with prompts. A higher value may decrease image quality. + negative_prompt (`str` or `list[str]`, *optional*): + Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images to generate for each prompt. + eta (`float`, *optional*, defaults to 0.0): + Applies to DDIM scheduling. Controlled by the eta parameter from the related literature. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A generator to ensure reproducibility in image generation. + latents (`torch.Tensor`, *optional*): + Predefined latent tensors to condition generation. + prompt_embeds (`torch.Tensor`, *optional*): + Text embeddings for the prompts. Overrides prompt string inputs for more flexibility. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Embeddings for negative prompts. Overrides string inputs if defined. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the primary prompt embeddings. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for negative prompt embeddings. + output_type (`str`, *optional*, defaults to "latent"): + Format of the generated output, either as a PIL image or as a NumPy array. + return_dict (`bool`, *optional*, defaults to `True`): + If `True`, returns a structured output. Otherwise returns a simple tuple. + callback_on_step_end (`Callable`, *optional*): + Functions called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`list[str]`, *optional*): + Tensor names to be included in callback function calls. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Adjusts noise levels based on guidance scale. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. default height and width + height = int((height // 16) * 16) + width = int((width // 16) * 16) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = self.transformer.dtype + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + text_encoder_index=0, + ) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, mu=1 + ) + else: + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + num_frames, + height, + width, + dtype, + device, + generator, + latents, + ) + + if control_camera_video is not None: + control_video_latents = resize_mask(control_camera_video, latents, process_first_frame_only=True) + control_video_latents = control_video_latents * 6 + control_latents = ( + torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents + ).to(device, dtype) + elif control_video is not None: + batch_size, channels, num_frames, height_video, width_video = control_video.shape + control_video = self.image_processor.preprocess( + control_video.permute(0, 2, 1, 3, 4).reshape( + batch_size * num_frames, channels, height_video, width_video + ), + height=height, + width=width, + ) + control_video = control_video.to(dtype=torch.float32) + control_video = control_video.reshape(batch_size, num_frames, channels, height, width).permute( + 0, 2, 1, 3, 4 + ) + control_video_latents = self.prepare_control_latents( + None, + control_video, + batch_size, + height, + width, + dtype, + device, + generator, + self.do_classifier_free_guidance, + )[1] + control_latents = ( + torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents + ).to(device, dtype) + else: + control_video_latents = torch.zeros_like(latents).to(device, dtype) + control_latents = ( + torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents + ).to(device, dtype) + + if ref_image is not None: + batch_size, channels, num_frames, height_video, width_video = ref_image.shape + ref_image = self.image_processor.preprocess( + ref_image.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height_video, width_video), + height=height, + width=width, + ) + ref_image = ref_image.to(dtype=torch.float32) + ref_image = ref_image.reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4) + + ref_image_latents = self.prepare_control_latents( + None, + ref_image, + batch_size, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + )[1] + + ref_image_latents_conv_in = torch.zeros_like(latents) + if latents.size()[2] != 1: + ref_image_latents_conv_in[:, :, :1] = ref_image_latents + ref_image_latents_conv_in = ( + torch.cat([ref_image_latents_conv_in] * 2) + if self.do_classifier_free_guidance + else ref_image_latents_conv_in + ).to(device, dtype) + control_latents = torch.cat([control_latents, ref_image_latents_conv_in], dim=1) + else: + ref_image_latents_conv_in = torch.zeros_like(latents) + ref_image_latents_conv_in = ( + torch.cat([ref_image_latents_conv_in] * 2) + if self.do_classifier_free_guidance + else ref_image_latents_conv_in + ).to(device, dtype) + control_latents = torch.cat([control_latents, ref_image_latents_conv_in], dim=1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + + # To latents.device + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( + dtype=latent_model_input.dtype + ) + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + control_latents=control_latents, + return_dict=False, + )[0] + if noise_pred.size()[1] != self.vae.config.latent_channels: + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # Convert to tensor + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return EasyAnimatePipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..8723138980081a8d04cb50bf174317f50acf1837 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py @@ -0,0 +1,1239 @@ +# Copyright 2025 The EasyAnimate team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image +from transformers import ( + BertModel, + BertTokenizer, + Qwen2Tokenizer, + Qwen2VLForConditionalGeneration, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import EasyAnimatePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import EasyAnimateInpaintPipeline + >>> from diffusers.pipelines.easyanimate.pipeline_easyanimate_inpaint import get_image_to_video_latent + >>> from diffusers.utils import export_to_video, load_image + + >>> pipe = EasyAnimateInpaintPipeline.from_pretrained( + ... "alibaba-pai/EasyAnimateV5.1-12b-zh-InP-diffusers", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." + >>> validation_image_start = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" + ... ) + + >>> validation_image_end = None + >>> sample_size = (448, 576) + >>> num_frames = 49 + >>> input_video, input_video_mask = get_image_to_video_latent( + ... [validation_image_start], validation_image_end, num_frames, sample_size + ... ) + + >>> video = pipe( + ... prompt, + ... num_frames=num_frames, + ... negative_prompt="Twisted body, limb deformities, text subtitles, comics, stillness, ugliness, errors, garbled text.", + ... height=sample_size[0], + ... width=sample_size[1], + ... video=input_video, + ... mask_video=input_video_mask, + ... ) + >>> export_to_video(video.frames[0], "output.mp4", fps=8) + ``` +""" + + +def preprocess_image(image, sample_size): + """ + Preprocess a single image (PIL.Image, numpy.ndarray, or torch.Tensor) to a resized tensor. + """ + if isinstance(image, torch.Tensor): + # If input is a tensor, assume it's in CHW format and resize using interpolation + image = torch.nn.functional.interpolate( + image.unsqueeze(0), size=sample_size, mode="bilinear", align_corners=False + ).squeeze(0) + elif isinstance(image, Image.Image): + # If input is a PIL image, resize and convert to numpy array + image = image.resize((sample_size[1], sample_size[0])) + image = np.array(image) + elif isinstance(image, np.ndarray): + # If input is a numpy array, resize using PIL + image = Image.fromarray(image).resize((sample_size[1], sample_size[0])) + image = np.array(image) + else: + raise ValueError("Unsupported input type. Expected PIL.Image, numpy.ndarray, or torch.Tensor.") + + # Convert to tensor if not already + if not isinstance(image, torch.Tensor): + image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0 # HWC -> CHW, normalize to [0, 1] + + return image + + +def get_image_to_video_latent(validation_image_start, validation_image_end, num_frames, sample_size): + """ + Generate latent representations for video from start and end images. Inputs can be PIL.Image, numpy.ndarray, or + torch.Tensor. + """ + input_video = None + input_video_mask = None + + if validation_image_start is not None: + # Preprocess the starting image(s) + if isinstance(validation_image_start, list): + image_start = [preprocess_image(img, sample_size) for img in validation_image_start] + else: + image_start = preprocess_image(validation_image_start, sample_size) + + # Create video tensor from the starting image(s) + if isinstance(image_start, list): + start_video = torch.cat( + [img.unsqueeze(1).unsqueeze(0) for img in image_start], + dim=2, + ) + input_video = torch.tile(start_video[:, :, :1], [1, 1, num_frames, 1, 1]) + input_video[:, :, : len(image_start)] = start_video + else: + input_video = torch.tile( + image_start.unsqueeze(1).unsqueeze(0), + [1, 1, num_frames, 1, 1], + ) + + # Normalize input video (already normalized in preprocess_image) + + # Create mask for the input video + input_video_mask = torch.zeros_like(input_video[:, :1]) + if isinstance(image_start, list): + input_video_mask[:, :, len(image_start) :] = 255 + else: + input_video_mask[:, :, 1:] = 255 + + # Handle ending image(s) if provided + if validation_image_end is not None: + if isinstance(validation_image_end, list): + image_end = [preprocess_image(img, sample_size) for img in validation_image_end] + end_video = torch.cat( + [img.unsqueeze(1).unsqueeze(0) for img in image_end], + dim=2, + ) + input_video[:, :, -len(end_video) :] = end_video + input_video_mask[:, :, -len(image_end) :] = 0 + else: + image_end = preprocess_image(validation_image_end, sample_size) + input_video[:, :, -1:] = image_end.unsqueeze(1).unsqueeze(0) + input_video_mask[:, :, -1:] = 0 + + elif validation_image_start is None: + # If no starting image is provided, initialize empty tensors + input_video = torch.zeros([1, 3, num_frames, sample_size[0], sample_size[1]]) + input_video_mask = torch.ones([1, 1, num_frames, sample_size[0], sample_size[1]]) * 255 + + return input_video, input_video_mask + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Resize mask information in magvit +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], size=target_size, mode="trilinear", align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], size=target_size, mode="trilinear", align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate(mask, size=target_size, mode="trilinear", align_corners=False) + return resized_mask + + +## Add noise to reference video +def add_noise_to_reference_video(image, ratio=None, generator=None): + if ratio is None: + sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device) + sigma = torch.exp(sigma).to(image.dtype) + else: + sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio + + if generator is not None: + image_noise = ( + torch.randn(image.size(), generator=generator, dtype=image.dtype, device=image.device) + * sigma[:, None, None, None, None] + ) + else: + image_noise = torch.randn_like(image) * sigma[:, None, None, None, None] + image_noise = torch.where(image == -1, torch.zeros_like(image), image_noise) + image = image + image_noise + return image + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class EasyAnimateInpaintPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using EasyAnimate. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1. + + Args: + vae ([`AutoencoderKLMagvit`]): + Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations. + text_encoder (`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel` | None): + EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1. + tokenizer (`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer` | None): + A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text. + transformer ([`EasyAnimateTransformer3DModel`]): + The EasyAnimate model designed by EasyAnimate Team. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKLMagvit, + text_encoder: Qwen2VLForConditionalGeneration | BertModel, + tokenizer: Qwen2Tokenizer | BertTokenizer, + transformer: EasyAnimateTransformer3DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.enable_text_attention_mask = ( + self.transformer.config.enable_text_attention_mask + if getattr(self, "transformer", None) is not None + else True + ) + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 4 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_spatial_compression_ratio, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + + # Copied from diffusers.pipelines.easyanimate.pipeline_easyanimate.EasyAnimatePipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 256, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + dtype (`torch.dtype`): + torch dtype + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + """ + dtype = dtype or self.text_encoder.dtype + device = device or self.text_encoder.device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + if isinstance(prompt, str): + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": prompt}], + } + ] + else: + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": _prompt}], + } + for _prompt in prompt + ] + text = [ + self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages + ] + + text_inputs = self.tokenizer( + text=text, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_attention_mask=True, + padding_side="right", + return_tensors="pt", + ) + text_inputs = text_inputs.to(self.text_encoder.device) + + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + if self.enable_text_attention_mask: + # Inference: Generation of the output + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ).hidden_states[-2] + else: + raise ValueError("LLM needs attention_mask") + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.to(device=device) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + if negative_prompt is not None and isinstance(negative_prompt, str): + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": negative_prompt}], + } + ] + else: + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": _negative_prompt}], + } + for _negative_prompt in negative_prompt + ] + text = [ + self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages + ] + + text_inputs = self.tokenizer( + text=text, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_attention_mask=True, + padding_side="right", + return_tensors="pt", + ) + text_inputs = text_inputs.to(self.text_encoder.device) + + text_input_ids = text_inputs.input_ids + negative_prompt_attention_mask = text_inputs.attention_mask + if self.enable_text_attention_mask: + # Inference: Generation of the output + negative_prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=negative_prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-2] + else: + raise ValueError("LLM needs attention_mask") + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device) + + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + height, + width, + dtype, + device, + generator, + do_classifier_free_guidance, + noise_aug_strength, + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + if mask is not None: + mask = mask.to(device=device, dtype=dtype) + new_mask = [] + bs = 1 + for i in range(0, mask.shape[0], bs): + mask_bs = mask[i : i + bs] + mask_bs = self.vae.encode(mask_bs)[0] + mask_bs = mask_bs.mode() + new_mask.append(mask_bs) + mask = torch.cat(new_mask, dim=0) + mask = mask * self.vae.config.scaling_factor + + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=dtype) + if self.transformer.config.add_noise_in_inpaint_model: + masked_image = add_noise_to_reference_video( + masked_image, ratio=noise_aug_strength, generator=generator + ) + new_mask_pixel_values = [] + bs = 1 + for i in range(0, masked_image.shape[0], bs): + mask_pixel_values_bs = masked_image[i : i + bs] + mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] + mask_pixel_values_bs = mask_pixel_values_bs.mode() + new_mask_pixel_values.append(mask_pixel_values_bs) + masked_image_latents = torch.cat(new_mask_pixel_values, dim=0) + masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + else: + masked_image_latents = None + + return mask, masked_image_latents + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + num_frames, + dtype, + device, + generator, + latents=None, + video=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_video_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae_temporal_compression_ratio + 1, + height // self.vae_spatial_compression_ratio, + width // self.vae_spatial_compression_ratio, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if return_video_latents or (latents is None and not is_strength_max): + video = video.to(device=device, dtype=dtype) + bs = 1 + new_video = [] + for i in range(0, video.shape[0], bs): + video_bs = video[i : i + bs] + video_bs = self.vae.encode(video_bs)[0] + video_bs = video_bs.sample() + new_video.append(video_bs) + video = torch.cat(new_video, dim=0) + video = video * self.vae.config.scaling_factor + + video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1) + video_latents = video_latents.to(device=device, dtype=dtype) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + latents = noise if is_strength_max else self.scheduler.scale_noise(video_latents, timestep, noise) + else: + latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + if hasattr(self.scheduler, "init_noise_sigma"): + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_video_latents: + outputs += (video_latents,) + + return outputs + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + num_frames: int | None = 49, + video: torch.FloatTensor = None, + mask_video: torch.FloatTensor = None, + masked_video_latents: torch.FloatTensor = None, + height: int | None = 512, + width: int | None = 512, + num_inference_steps: int | None = 50, + guidance_scale: float | None = 5.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float | None = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + guidance_rescale: float = 0.0, + strength: float = 1.0, + noise_aug_strength: float = 0.0563, + timesteps: list[int] | None = None, + ): + r""" + The call function to the pipeline for generation with HunyuanDiT. + + Examples: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + num_frames (`int`, *optional*): + Length of the video to be generated in seconds. This parameter influences the number of frames and + continuity of generated content. + video (`torch.FloatTensor`, *optional*): + A tensor representing an input video, which can be modified depending on the prompts provided. + mask_video (`torch.FloatTensor`, *optional*): + A tensor to specify areas of the video to be masked (omitted from generation). + masked_video_latents (`torch.FloatTensor`, *optional*): + Latents from masked portions of the video, utilized during image generation. + height (`int`, *optional*): + The height in pixels of the generated image or video frames. + width (`int`, *optional*): + The width in pixels of the generated image or video frames. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image but slower + inference time. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is effective when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to exclude in image generation. If not defined, you need to provide + `negative_prompt_embeds`. This parameter is ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + A parameter defined in the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies to the + [`~schedulers.DDIMScheduler`] and is ignored in other schedulers. It adjusts noise level during the + inference process. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) for setting + random seeds which helps in making generation deterministic. + latents (`torch.Tensor`, *optional*): + A pre-computed latent representation which can be used to guide the generation process. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings, aiding in fine-tuning what should not be represented in the + outputs. If not provided, embeddings are generated from the `negative_prompt` argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask guiding the focus of the model on specific parts of the prompt text. Required when using + `prompt_embeds`. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt, needed when `negative_prompt_embeds` are used. + output_type (`str`, *optional*, defaults to `"latent"`): + The output format of the generated image. Choose between `PIL.Image` and `np.array` to define how you + want the results to be formatted. + return_dict (`bool`, *optional*, defaults to `True`): + If set to `True`, a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] will be returned; + otherwise, a tuple containing the generated images and safety flags will be returned. + callback_on_step_end (`Callable[[int, int], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, + *optional*): + A callback function (or a list of them) that will be executed at the end of each denoising step, + allowing for custom processing during generation. + callback_on_step_end_tensor_inputs (`list[str]`, *optional*): + Specifies which tensor inputs should be included in the callback function. If not defined, all tensor + inputs will be passed, facilitating enhanced logging or monitoring of the generation process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Rescale parameter for adjusting noise configuration based on guidance rescale. Based on findings from + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + strength (`float`, *optional*, defaults to 1.0): + Affects the overall styling or quality of the generated output. Values closer to 1 usually provide + direct adherence to prompts. + + Examples: + # Example usage of the function for generating images based on prompts. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + Returns either a structured output containing generated images and their metadata when `return_dict` is + `True`, or a simpler tuple, where the first element is a list of generated images and the second + element indicates if any of them contain "not-safe-for-work" (NSFW) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. default height and width + height = int(height // 16 * 16) + width = int(width // 16 * 16) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = self.transformer.dtype + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + # 4. set timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, mu=1 + ) + else: + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + if video is not None: + batch_size, channels, num_frames, height_video, width_video = video.shape + init_video = self.image_processor.preprocess( + video.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height_video, width_video), + height=height, + width=width, + ) + init_video = init_video.to(dtype=torch.float32) + init_video = init_video.reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4) + else: + init_video = None + + # Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_transformer = self.transformer.config.in_channels + return_image_latents = num_channels_transformer == num_channels_latents + + # 5. Prepare latents. + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + num_frames, + dtype, + device, + generator, + latents, + video=init_video, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_video_latents=return_image_latents, + ) + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 6. Prepare inpaint latents if it needs. + if mask_video is not None: + if (mask_video == 255).all(): + mask = torch.zeros_like(latents).to(device, dtype) + # Use zero latents if we want to t2v. + if self.transformer.config.resize_inpaint_mask_directly: + mask_latents = torch.zeros_like(latents)[:, :1].to(device, dtype) + else: + mask_latents = torch.zeros_like(latents).to(device, dtype) + masked_video_latents = torch.zeros_like(latents).to(device, dtype) + + mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents + ) + inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(dtype) + else: + # Prepare mask latent variables + batch_size, channels, num_frames, height_video, width_video = mask_video.shape + mask_condition = self.mask_processor.preprocess( + mask_video.permute(0, 2, 1, 3, 4).reshape( + batch_size * num_frames, channels, height_video, width_video + ), + height=height, + width=width, + ) + mask_condition = mask_condition.to(dtype=torch.float32) + mask_condition = mask_condition.reshape(batch_size, num_frames, channels, height, width).permute( + 0, 2, 1, 3, 4 + ) + + if num_channels_transformer != num_channels_latents: + mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1]) + if masked_video_latents is None: + masked_video = ( + init_video * (mask_condition_tile < 0.5) + + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1 + ) + else: + masked_video = masked_video_latents + + if self.transformer.config.resize_inpaint_mask_directly: + _, masked_video_latents = self.prepare_mask_latents( + None, + masked_video, + batch_size, + height, + width, + dtype, + device, + generator, + self.do_classifier_free_guidance, + noise_aug_strength=noise_aug_strength, + ) + mask_latents = resize_mask( + 1 - mask_condition, masked_video_latents, self.vae.config.cache_mag_vae + ) + mask_latents = mask_latents.to(device, dtype) * self.vae.config.scaling_factor + else: + mask_latents, masked_video_latents = self.prepare_mask_latents( + mask_condition_tile, + masked_video, + batch_size, + height, + width, + dtype, + device, + generator, + self.do_classifier_free_guidance, + noise_aug_strength=noise_aug_strength, + ) + + mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) + if self.do_classifier_free_guidance + else masked_video_latents + ) + inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(dtype) + else: + inpaint_latents = None + + mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1]) + mask = F.interpolate(mask, size=latents.size()[-3:], mode="trilinear", align_corners=True).to( + device, dtype + ) + else: + if num_channels_transformer != num_channels_latents: + mask = torch.zeros_like(latents).to(device, dtype) + if self.transformer.config.resize_inpaint_mask_directly: + mask_latents = torch.zeros_like(latents)[:, :1].to(device, dtype) + else: + mask_latents = torch.zeros_like(latents).to(device, dtype) + masked_video_latents = torch.zeros_like(latents).to(device, dtype) + + mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents + ) + inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(dtype) + else: + mask = torch.zeros_like(init_video[:, :1]) + mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1]) + mask = F.interpolate(mask, size=latents.size()[-3:], mode="trilinear", align_corners=True).to( + device, dtype + ) + + inpaint_latents = None + + # Check that sizes of mask, masked image and latents match + if num_channels_transformer != num_channels_latents: + num_channels_mask = mask_latents.shape[1] + num_channels_masked_image = masked_video_latents.shape[1] + if ( + num_channels_latents + num_channels_mask + num_channels_masked_image + != self.transformer.config.in_channels + ): + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects" + f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" + " `pipeline.transformer` or your `mask_image` or `image` input." + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + + # To latents.device + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( + dtype=latent_model_input.dtype + ) + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + inpaint_latents=inpaint_latents, + return_dict=False, + )[0] + if noise_pred.size()[1] != self.vae.config.latent_channels: + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if num_channels_transformer == num_channels_latents: + init_latents_proper = image_latents + init_mask = mask + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep], noise) + ) + else: + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + latents = 1 / self.vae.config.scaling_factor * latents + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return EasyAnimatePipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..4fd904ae7dfeb3e842cbb62844df643c7e5b9889 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/easyanimate/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class EasyAnimatePipelineOutput(BaseOutput): + r""" + Output class for EasyAnimate pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea25c148e2f1d09f3a7fcf124020d6d24bd76935 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/__init__.py @@ -0,0 +1,73 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["FluxPipelineOutput", "FluxPriorReduxPipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modeling_flux"] = ["ReduxImageEncoder"] + _import_structure["pipeline_flux"] = ["FluxPipeline"] + _import_structure["pipeline_flux_control"] = ["FluxControlPipeline"] + _import_structure["pipeline_flux_control_img2img"] = ["FluxControlImg2ImgPipeline"] + _import_structure["pipeline_flux_control_inpaint"] = ["FluxControlInpaintPipeline"] + _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"] + _import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"] + _import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"] + _import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"] + _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"] + _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"] + _import_structure["pipeline_flux_kontext"] = ["FluxKontextPipeline"] + _import_structure["pipeline_flux_kontext_inpaint"] = ["FluxKontextInpaintPipeline"] + _import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modeling_flux import ReduxImageEncoder + from .pipeline_flux import FluxPipeline + from .pipeline_flux_control import FluxControlPipeline + from .pipeline_flux_control_img2img import FluxControlImg2ImgPipeline + from .pipeline_flux_control_inpaint import FluxControlInpaintPipeline + from .pipeline_flux_controlnet import FluxControlNetPipeline + from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline + from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline + from .pipeline_flux_fill import FluxFillPipeline + from .pipeline_flux_img2img import FluxImg2ImgPipeline + from .pipeline_flux_inpaint import FluxInpaintPipeline + from .pipeline_flux_kontext import FluxKontextPipeline + from .pipeline_flux_kontext_inpaint import FluxKontextInpaintPipeline + from .pipeline_flux_prior_redux import FluxPriorReduxPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/modeling_flux.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/modeling_flux.py new file mode 100644 index 0000000000000000000000000000000000000000..916e3fbe59533d39af85cba3adeae3c771726b02 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/modeling_flux.py @@ -0,0 +1,46 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin +from ...utils import BaseOutput + + +@dataclass +class ReduxImageEncoderOutput(BaseOutput): + image_embeds: torch.Tensor | None = None + + +class ReduxImageEncoder(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + redux_dim: int = 1152, + txt_in_features: int = 4096, + ) -> None: + super().__init__() + + self.redux_up = nn.Linear(redux_dim, txt_in_features * 3) + self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features) + + def forward(self, x: torch.Tensor) -> ReduxImageEncoderOutput: + projected_x = self.redux_down(nn.functional.silu(self.redux_up(x))) + + return ReduxImageEncoderOutput(image_embeds=projected_x) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux.py new file mode 100644 index 0000000000000000000000000000000000000000..be2bbe2acc6a6fe28866309a4ba284208e634eae --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux.py @@ -0,0 +1,1020 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxPipeline + + >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] + >>> image.save("flux.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxPipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, + FluxIPAdapterMixin, +): + r""" + The Flux pipeline for text-to-image generation. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str] | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 512, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + return image_embeds + + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_ip_adapter_image in ip_adapter_image: + single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) + image_embeds.append(single_image_embeds[None, :]) + else: + if not isinstance(ip_adapter_image_embeds, list): + ip_adapter_image_embeds = [ip_adapter_image_embeds] + + if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_image_embeds in ip_adapter_image_embeds: + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for single_image_embeds in image_embeds: + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + negative_prompt: str | list[str] = None, + negative_prompt_2: str | list[str] | None = None, + true_cfg_scale: float = 1.0, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 28, + sigmas: list[float] | None = None, + guidance_scale: float = 3.5, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + negative_ip_adapter_image: PipelineImageInput | None = None, + negative_ip_adapter_image_embeds: list[torch.Tensor] | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and + `negative_prompt` is provided. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_ip_adapter_image: + (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + negative_ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + if do_true_cfg: + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + negative_text_ids, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + timestep_device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( + negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None + ): + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( + negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None + ): + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: + negative_image_embeds = self.prepare_ip_adapter_image_embeds( + negative_ip_adapter_image, + negative_ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + + # 6. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_control.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_control.py new file mode 100644 index 0000000000000000000000000000000000000000..84e21b59ce5824103f79a1d482e17877f016b9bd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_control.py @@ -0,0 +1,915 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from controlnet_aux import CannyDetector + >>> from diffusers import FluxControlPipeline + >>> from diffusers.utils import load_image + + >>> pipe = FluxControlPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-Canny-dev", torch_dtype=torch.bfloat16 + ... ).to("cuda") + + >>> prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts." + >>> control_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png" + ... ) + + >>> processor = CannyDetector() + >>> control_image = processor( + ... control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024 + ... ) + + >>> image = pipe( + ... prompt=prompt, + ... control_image=control_image, + ... height=1024, + ... width=1024, + ... num_inference_steps=50, + ... guidance_scale=30.0, + ... ).images[0] + >>> image.save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxControlPipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): + r""" + The Flux pipeline for controllable text-to-image generation with image conditions. + + Reference: https://bfl.ai/flux-1-tools + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.vae_latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.vae_latent_channels + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str] | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 512, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + control_image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 28, + sigmas: list[float] | None = None, + guidance_scale: float = 3.5, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with prompt at the expense of lower image quality. + + Guidance-distilled models approximates true classifier-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 8 + + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + + if control_image.ndim == 4: + control_image = self.vae.encode(control_image).latent_dist.sample(generator=generator) + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + timestep_device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents, control_image], dim=2) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_control_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_control_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..b455c611e0ae527e25ff95899b70dd71dee6944d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_control_img2img.py @@ -0,0 +1,945 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from controlnet_aux import CannyDetector + >>> from diffusers import FluxControlImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> pipe = FluxControlImg2ImgPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-Canny-dev", torch_dtype=torch.bfloat16 + ... ).to("cuda") + + >>> prompt = "A robot made of exotic candies and chocolates of different kinds. Abstract background" + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/watercolor-painting.jpg" + ... ) + >>> control_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png" + ... ) + + >>> processor = CannyDetector() + >>> control_image = processor( + ... control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024 + ... ) + + >>> image = pipe( + ... prompt=prompt, + ... image=image, + ... control_image=control_image, + ... strength=0.8, + ... height=1024, + ... width=1024, + ... num_inference_steps=50, + ... guidance_scale=30.0, + ... ).images[0] + >>> image.save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): + r""" + The Flux pipeline for image inpainting. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str] | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 512, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_image_ids + + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, latent_image_ids + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + strength: float = 0.6, + num_inference_steps: int = 28, + sigmas: list[float] | None = None, + guidance_scale: float = 7.0, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Preprocess image + init_image = self.image_processor.preprocess(image, height=height, width=width) + init_image = init_image.to(dtype=torch.float32) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4.Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + timestep_device, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 8 + + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + + if control_image.ndim == 4: + control_image = self.vae.encode(control_image).latent_dist.sample(generator=generator) + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents, control_image], dim=2) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..2d1e05493a11139aa2003a45eee3e9e33109e912 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py @@ -0,0 +1,1168 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import ( + CLIPTextModel, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +) +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + import torch + from diffusers import FluxControlInpaintPipeline + from diffusers.models.transformers import FluxTransformer2DModel + from transformers import T5EncoderModel + from diffusers.utils import load_image, make_image_grid + from image_gen_aux import DepthPreprocessor # https://github.com/huggingface/image_gen_aux + from PIL import Image + import numpy as np + + pipe = FluxControlInpaintPipeline.from_pretrained( + "black-forest-labs/FLUX.1-Depth-dev", + torch_dtype=torch.bfloat16, + ) + # use following lines if you have GPU constraints + # --------------------------------------------------------------- + transformer = FluxTransformer2DModel.from_pretrained( + "sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="transformer", torch_dtype=torch.bfloat16 + ) + text_encoder_2 = T5EncoderModel.from_pretrained( + "sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="text_encoder_2", torch_dtype=torch.bfloat16 + ) + pipe.transformer = transformer + pipe.text_encoder_2 = text_encoder_2 + pipe.enable_model_cpu_offload() + # --------------------------------------------------------------- + pipe.to("cuda") + + prompt = "a blue robot singing opera with human-like expressions" + image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png") + + head_mask = np.zeros_like(image) + head_mask[65:580, 300:642] = 255 + mask_image = Image.fromarray(head_mask) + + processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf") + control_image = processor(image)[0].convert("RGB") + + output = pipe( + prompt=prompt, + image=image, + control_image=control_image, + mask_image=mask_image, + num_inference_steps=30, + strength=0.9, + guidance_scale=10.0, + generator=torch.Generator().manual_seed(42), + ).images[0] + make_image_grid([image, control_image, mask_image, output.resize(image.size)], rows=1, cols=4).save( + "output.png" + ) + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxControlInpaintPipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): + r""" + The Flux pipeline for image inpainting using Flux-dev-Depth/Canny. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + vae_latent_channels=latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str] | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 512, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_image_ids + + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, noise, image_latents, latent_image_ids + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + def prepare_mask_latents( + self, + image, + mask_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + image = self.image_processor.preprocess(image, height=height, width=width) + mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width) + + masked_image = image * (1 - mask_image) + masked_image = masked_image.to(device=device, dtype=dtype) + + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask_image = torch.nn.functional.interpolate(mask_image, size=(height, width)) + mask_image = mask_image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == num_channels_latents: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask_image.shape[0] < batch_size: + if not batch_size % mask_image.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask_image.shape[0]} mask_image were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask_image = mask_image.repeat(batch_size // mask_image.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + mask_image = self._pack_latents( + mask_image.repeat(1, num_channels_latents, 1, 1), + batch_size, + num_channels_latents, + height, + width, + ) + masked_image_latents = torch.cat((masked_image_latents, mask_image), dim=-1) + + return mask_image, masked_image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + strength: float = 0.6, + num_inference_steps: int = 28, + sigmas: list[float] | None = None, + guidance_scale: float = 7.0, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + mask_image_latent (`torch.Tensor`, `list[torch.Tensor]`): + `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask + latents tensor will be generated by `mask_image`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + device = self._execution_device + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 3. Preprocess mask and image + num_channels_latents = self.vae.config.latent_channels + if masked_image_latents is not None: + # pre computed masked_image_latents and mask_image + masked_image_latents = masked_image_latents.to(latents.device) + mask = mask_image.to(latents.device) + else: + mask, masked_image_latents = self.prepare_mask_latents( + image, + mask_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + init_image = self.image_processor.preprocess(image, height=height, width=width) + init_image = init_image.to(dtype=torch.float32) + + # 4.Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + timestep_device, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 8 + + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + + if control_image.ndim == 4: + control_image = self.vae.encode(control_image).latent_dist.sample(generator=generator) + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + latents, noise, image_latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height_8 = 2 * (int(height) // (self.vae_scale_factor * 2)) + width_8 = 2 * (int(width) // (self.vae_scale_factor * 2)) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents, control_image], dim=2) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # for 64 channel transformer only. + init_mask = mask + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + image_latents, torch.tensor([noise_timestep]), noise + ) + else: + init_latents_proper = image_latents + init_latents_proper = self._pack_latents( + init_latents_proper, batch_size * num_images_per_prompt, num_channels_latents, height_8, width_8 + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..d8dcdfcd46403ef6d625181e146094dc2aac8a12 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -0,0 +1,1183 @@ +# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers.utils import load_image + >>> from diffusers import FluxControlNetPipeline + >>> from diffusers import FluxControlNetModel + + >>> base_model = "black-forest-labs/FLUX.1-dev" + >>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny" + >>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16) + >>> pipe = FluxControlNetPipeline.from_pretrained( + ... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") + >>> prompt = "A girl in city, 25 years old, cool, futuristic" + >>> image = pipe( + ... prompt, + ... control_image=control_image, + ... control_guidance_start=0.2, + ... control_guidance_end=0.8, + ... controlnet_conditioning_scale=1.0, + ... num_inference_steps=28, + ... guidance_scale=3.5, + ... ).images[0] + >>> image.save("flux.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin): + r""" + The Flux pipeline for text-to-image generation. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "control_image"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + controlnet: FluxControlNetModel + | list[FluxControlNetModel] + | tuple[FluxControlNetModel] + | FluxMultiControlNetModel, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + ): + super().__init__() + if isinstance(controlnet, (list, tuple)): + controlnet = FluxMultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + controlnet=controlnet, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str] | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 512, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + return image_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_ip_adapter_image in ip_adapter_image: + single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) + image_embeds.append(single_image_embeds[None, :]) + else: + if not isinstance(ip_adapter_image_embeds, list): + ip_adapter_image_embeds = [ip_adapter_image_embeds] + + if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_image_embeds in ip_adapter_image_embeds: + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for single_image_embeds in image_embeds: + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + negative_prompt: str | list[str] = None, + negative_prompt_2: str | list[str] | None = None, + true_cfg_scale: float = 1.0, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 28, + sigmas: list[float] | None = None, + guidance_scale: float = 7.0, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + control_image: PipelineImageInput = None, + control_mode: int | list[int] | None = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + negative_ip_adapter_image: PipelineImageInput | None = None, + negative_ip_adapter_image_embeds: list[torch.Tensor] | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + control_mode (`int` or `list[int]`,, *optional*, defaults to None): + The control mode when applying ControlNet-Union. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_ip_adapter_image: + (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + negative_ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + dtype = self.transformer.dtype + + # 3. Prepare text embeddings + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + if do_true_cfg: + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + _, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 3. Prepare control image + num_channels_latents = self.transformer.config.in_channels // 4 + if isinstance(self.controlnet, FluxControlNetModel): + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image.shape[-2:] + + # xlab controlnet has a input_hint_block and instantx controlnet does not + controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True + if self.controlnet.input_hint_block is None: + # vae encode + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + # Here we ensure that `control_mode` has the same length as the control_image. + if control_mode is not None: + if not isinstance(control_mode, int): + raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`") + control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) + control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1) + + elif isinstance(self.controlnet, FluxMultiControlNetModel): + control_images = [] + # xlab controlnet has a input_hint_block and instantx controlnet does not + controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True + for i, control_image_ in enumerate(control_image): + control_image_ = self.prepare_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image_.shape[-2:] + + if self.controlnet.nets[0].input_hint_block is None: + # vae encode + control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator) + control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image_.shape[2:] + control_image_ = self._pack_latents( + control_image_, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + control_images.append(control_image_) + + control_image = control_images + + # Here we ensure that `control_mode` has the same length as the control_image. + if isinstance(control_mode, list) and len(control_mode) != len(control_image): + raise ValueError( + "For Multi-ControlNet, `control_mode` must be a list of the same " + + " length as the number of controlnets (control images) specified" + ) + if not isinstance(control_mode, list): + control_mode = [control_mode] * len(control_image) + # set control mode + control_modes = [] + for cmode in control_mode: + if cmode is None: + cmode = -1 + control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long) + control_modes.append(control_mode) + control_mode = control_modes + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + timestep_device, + sigmas=sigmas, + mu=mu, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps) + + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( + negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None + ): + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( + negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None + ): + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: + negative_image_embeds = self.prepare_ip_adapter_image_embeds( + negative_ip_adapter_image, + negative_ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + + # 7. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + if isinstance(self.controlnet, FluxMultiControlNetModel): + use_guidance = self.controlnet.nets[0].config.guidance_embeds + else: + use_guidance = self.controlnet.config.guidance_embeds + + guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None + guidance = guidance.expand(latents.shape[0]) if guidance is not None else None + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + # controlnet + controlnet_block_samples, controlnet_single_block_samples = self.controlnet( + hidden_states=latents, + controlnet_cond=control_image, + controlnet_mode=control_mode, + conditioning_scale=cond_scale, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + ) + + guidance = ( + torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None + ) + guidance = guidance.expand(latents.shape[0]) if guidance is not None else None + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + controlnet_block_samples=controlnet_block_samples, + controlnet_single_block_samples=controlnet_single_block_samples, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + controlnet_blocks_repeat=controlnet_blocks_repeat, + )[0] + + if do_true_cfg: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + controlnet_block_samples=controlnet_block_samples, + controlnet_single_block_samples=controlnet_single_block_samples, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + controlnet_blocks_repeat=controlnet_blocks_repeat, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + control_image = callback_outputs.pop("control_image", control_image) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py new file mode 100644 index 0000000000000000000000000000000000000000..fdaff9b0af8a38fa9dfd9af2b0ca52b7cbd92859 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -0,0 +1,1009 @@ +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import ( + CLIPTextModel, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxControlNetImg2ImgPipeline, FluxControlNetModel + >>> from diffusers.utils import load_image + + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + + >>> controlnet = FluxControlNetModel.from_pretrained( + ... "InstantX/FLUX.1-dev-Controlnet-Canny-alpha", torch_dtype=torch.bfloat16 + ... ) + + >>> pipe = FluxControlNetImg2ImgPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-schnell", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> pipe.text_encoder.to(torch.float16) + >>> pipe.controlnet.to(torch.float16) + >>> pipe.to("cuda") + + >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") + >>> init_image = load_image( + ... "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + ... ) + + >>> prompt = "A girl in city, 25 years old, cool, futuristic" + >>> image = pipe( + ... prompt, + ... image=init_image, + ... control_image=control_image, + ... control_guidance_start=0.2, + ... control_guidance_end=0.8, + ... controlnet_conditioning_scale=1.0, + ... strength=0.7, + ... num_inference_steps=2, + ... guidance_scale=3.5, + ... ).images[0] + >>> image.save("flux_controlnet_img2img.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): + r""" + The Flux controlnet pipeline for image-to-image generation. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "control_image"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + controlnet: FluxControlNetModel + | list[FluxControlNetModel] + | tuple[FluxControlNetModel] + | FluxMultiControlNetModel, + ): + super().__init__() + if isinstance(controlnet, (list, tuple)): + controlnet = FluxMultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + controlnet=controlnet, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str] | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 512, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + prompt_2, + strength, + height, + width, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + pooled_prompt_embeds=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % self.vae_scale_factor * 2 != 0 or width % self.vae_scale_factor * 2 != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_image_ids + + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, latent_image_ids + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + strength: float = 0.6, + num_inference_steps: int = 28, + sigmas: list[float] | None = None, + guidance_scale: float = 7.0, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + control_mode: int | list[int] | None = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. + image (`PIL.Image.Image` or `list[PIL.Image.Image]` or `torch.FloatTensor`): + The image(s) to modify with the pipeline. + control_image (`PIL.Image.Image` or `list[PIL.Image.Image]` or `torch.FloatTensor`): + The ControlNet input condition. Image to control the generation. + height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 0.6): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. + num_inference_steps (`int`, *optional*, defaults to 28): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). + control_mode (`int` or `list[int]`, *optional*): + The mode for the ControlNet. If multiple ControlNets are used, this should be a list. + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original transformer. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or more [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to + make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + Additional keyword arguments to be passed to the joint attention mechanism. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising step during the inference. + callback_on_step_end_tensor_inputs (`list[str]`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. + max_sequence_length (`int`, *optional*, defaults to 512): + The maximum length of the sequence to be generated. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + self.check_inputs( + prompt, + prompt_2, + strength, + height, + width, + callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + dtype = self.transformer.dtype + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + init_image = self.image_processor.preprocess(image, height=height, width=width) + init_image = init_image.to(dtype=torch.float32) + + num_channels_latents = self.transformer.config.in_channels // 4 + + if isinstance(self.controlnet, FluxControlNetModel): + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image.shape[-2:] + + # xlab controlnet has a input_hint_block and instantx controlnet does not + controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True + if self.controlnet.input_hint_block is None: + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + if control_mode is not None: + control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) + control_mode = control_mode.reshape([-1, 1]) + + elif isinstance(self.controlnet, FluxMultiControlNetModel): + control_images = [] + + # xlab controlnet has a input_hint_block and instantx controlnet does not + controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True + for i, control_image_ in enumerate(control_image): + control_image_ = self.prepare_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image_.shape[-2:] + + if self.controlnet.nets[0].input_hint_block is None: + control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator) + control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + height_control_image, width_control_image = control_image_.shape[2:] + control_image_ = self._pack_latents( + control_image_, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + control_images.append(control_image_) + + control_image = control_images + + control_mode_ = [] + if isinstance(control_mode, list): + for cmode in control_mode: + if cmode is None: + control_mode_.append(-1) + else: + control_mode_.append(cmode) + control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) + control_mode = control_mode.reshape([-1, 1]) + + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + timestep_device, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + if isinstance(self.controlnet, FluxMultiControlNetModel): + use_guidance = self.controlnet.nets[0].config.guidance_embeds + else: + use_guidance = self.controlnet.config.guidance_embeds + + guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None + guidance = guidance.expand(latents.shape[0]) if guidance is not None else None + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + controlnet_block_samples, controlnet_single_block_samples = self.controlnet( + hidden_states=latents, + controlnet_cond=control_image, + controlnet_mode=control_mode, + conditioning_scale=cond_scale, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + ) + + guidance = ( + torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None + ) + guidance = guidance.expand(latents.shape[0]) if guidance is not None else None + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + controlnet_block_samples=controlnet_block_samples, + controlnet_single_block_samples=controlnet_single_block_samples, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + controlnet_blocks_repeat=controlnet_blocks_repeat, + )[0] + + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + control_image = callback_outputs.pop("control_image", control_image) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py new file mode 100644 index 0000000000000000000000000000000000000000..eed671152bc9b88fa6374e3fa603968c39ef5d91 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -0,0 +1,1210 @@ +import inspect +from typing import Any, Callable + +import numpy as np +import PIL +import torch +from transformers import ( + CLIPTextModel, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxControlNetInpaintPipeline + >>> from diffusers.models import FluxControlNetModel + >>> from diffusers.utils import load_image + + >>> controlnet = FluxControlNetModel.from_pretrained( + ... "InstantX/FLUX.1-dev-controlnet-canny", torch_dtype=torch.float16 + ... ) + >>> pipe = FluxControlNetInpaintPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-schnell", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> control_image = load_image( + ... "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg" + ... ) + >>> init_image = load_image( + ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + ... ) + >>> mask_image = load_image( + ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + ... ) + + >>> prompt = "A girl holding a sign that says InstantX" + >>> image = pipe( + ... prompt, + ... image=init_image, + ... mask_image=mask_image, + ... control_image=control_image, + ... control_guidance_start=0.2, + ... control_guidance_end=0.8, + ... controlnet_conditioning_scale=0.7, + ... strength=0.7, + ... num_inference_steps=28, + ... guidance_scale=3.5, + ... ).images[0] + >>> image.save("flux_controlnet_inpaint.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): + r""" + The Flux controlnet pipeline for inpainting. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "control_image", "mask", "masked_image_latents"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + controlnet: FluxControlNetModel + | list[FluxControlNetModel] + | tuple[FluxControlNetModel] + | FluxMultiControlNetModel, + ): + super().__init__() + if isinstance(controlnet, (list, tuple)): + controlnet = FluxMultiControlNetModel(controlnet) + + self.register_modules( + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + transformer=transformer, + controlnet=controlnet, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + vae_latent_channels=latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str] | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 512, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + prompt_2, + image, + mask_image, + strength, + height, + width, + output_type, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + else: + noise = latents.to(device) + latents = noise + + noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, noise, image_latents, latent_image_ids + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 16: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + mask = self._pack_latents( + mask.repeat(1, num_channels_latents, 1, 1), + batch_size, + num_channels_latents, + height, + width, + ) + + return mask, masked_image_latents + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: PipelineImageInput = None, + control_image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + strength: float = 0.6, + padding_mask_crop: int | None = None, + sigmas: list[float] | None = None, + num_inference_steps: int = 28, + guidance_scale: float = 7.0, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + control_mode: int | list[int] | None = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. + image (`PIL.Image.Image` or `list[PIL.Image.Image]` or `torch.FloatTensor`): + The image(s) to inpaint. + mask_image (`PIL.Image.Image` or `list[PIL.Image.Image]` or `torch.FloatTensor`): + The mask image(s) to use for inpainting. White pixels in the mask will be repainted, while black pixels + will be preserved. + masked_image_latents (`torch.FloatTensor`, *optional*): + Pre-generated masked image latents. + control_image (`PIL.Image.Image` or `list[PIL.Image.Image]` or `torch.FloatTensor`): + The ControlNet input condition. Image to control the generation. + height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 0.6): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. + padding_mask_crop (`int`, *optional*): + The size of the padding to use when cropping the mask. + num_inference_steps (`int`, *optional*, defaults to 28): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + control_mode (`int` or `list[int]`, *optional*): + The mode for the ControlNet. If multiple ControlNets are used, this should be a list. + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original transformer. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or more [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to + make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + Additional keyword arguments to be passed to the joint attention mechanism. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising step during the inference. + callback_on_step_end_tensor_inputs (`list[str]`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. + max_sequence_length (`int`, *optional*, defaults to 512): + The maximum length of the sequence to be generated. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + global_height = height + global_width = width + + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs + self.check_inputs( + prompt, + prompt_2, + image, + mask_image, + strength, + height, + width, + output_type=output_type, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + padding_mask_crop=padding_mask_crop, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + dtype = self.transformer.dtype + + # 3. Encode input prompt + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region( + mask_image, global_width, global_height, pad=padding_mask_crop + ) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 5. Prepare control image + num_channels_latents = self.transformer.config.in_channels // 4 + if isinstance(self.controlnet, FluxControlNetModel): + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image.shape[-2:] + + # xlab controlnet has a input_hint_block and instantx controlnet does not + controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True + if self.controlnet.input_hint_block is None: + # vae encode + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + # set control mode + if control_mode is not None: + control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) + control_mode = control_mode.reshape([-1, 1]) + + elif isinstance(self.controlnet, FluxMultiControlNetModel): + control_images = [] + + # xlab controlnet has a input_hint_block and instantx controlnet does not + controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True + for i, control_image_ in enumerate(control_image): + control_image_ = self.prepare_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image_.shape[-2:] + + if self.controlnet.nets[0].input_hint_block is None: + # vae encode + control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator) + control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image_.shape[2:] + control_image_ = self._pack_latents( + control_image_, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) + + control_images.append(control_image_) + + control_image = control_images + + # set control mode + control_mode_ = [] + if isinstance(control_mode, list): + for cmode in control_mode: + if cmode is None: + control_mode_.append(-1) + else: + control_mode_.append(cmode) + control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) + control_mode = control_mode.reshape([-1, 1]) + + # 6. Prepare timesteps + + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = (int(global_height) // self.vae_scale_factor // 2) * ( + int(global_width) // self.vae_scale_factor // 2 + ) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + timestep_device, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 7. Prepare latent variables + + latents, noise, image_latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 8. Prepare mask latents + mask_condition = self.mask_processor.preprocess( + mask_image, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords + ) + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + global_height, + global_width, + prompt_embeds.dtype, + device, + generator, + ) + + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps) + + # 9. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + # predict the noise residual + if isinstance(self.controlnet, FluxMultiControlNetModel): + use_guidance = self.controlnet.nets[0].config.guidance_embeds + else: + use_guidance = self.controlnet.config.guidance_embeds + if use_guidance: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + controlnet_block_samples, controlnet_single_block_samples = self.controlnet( + hidden_states=latents, + controlnet_cond=control_image, + controlnet_mode=control_mode, + conditioning_scale=cond_scale, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + ) + + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + controlnet_block_samples=controlnet_block_samples, + controlnet_single_block_samples=controlnet_single_block_samples, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + controlnet_blocks_repeat=controlnet_blocks_repeat, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # For inpainting, we need to apply the mask and add the masked image latents + init_latents_proper = image_latents + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + control_image = callback_outputs.pop("control_image", control_image) + mask = callback_outputs.pop("mask", mask) + masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # Post-processing + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, global_height, global_width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_fill.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_fill.py new file mode 100644 index 0000000000000000000000000000000000000000..cf929f53fc6dc01fd3a604a6874e9af46849abe0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -0,0 +1,1066 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...image_processor import VaeImageProcessor +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxFillPipeline + >>> from diffusers.utils import load_image + + >>> image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup.png") + >>> mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup_mask.png") + + >>> pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16) + >>> pipe.enable_model_cpu_offload() # save some VRAM by offloading the model to CPU + + >>> image = pipe( + ... prompt="a white paper cup", + ... image=image, + ... mask_image=mask, + ... height=1632, + ... width=1232, + ... guidance_scale=30, + ... num_inference_steps=50, + ... max_sequence_length=512, + ... generator=torch.Generator("cpu").manual_seed(0), + ... ).images[0] + >>> image.save("flux_fill.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class FluxFillPipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): + r""" + The Flux Fill pipeline for image inpainting/outpainting. + + Reference: https://blackforestlabs.ai/flux-1-tools/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels + ) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + vae_latent_channels=self.latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + # 1. calculate the height and width of the latents + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + # 2. encode the masked image + if masked_image.shape[1] == num_channels_latents: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + # 3. duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + batch_size = batch_size * num_images_per_prompt + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # 4. pack the masked_image_latents + # batch_size, num_channels_latents, height, width -> batch_size, height//2 * width//2 , num_channels_latents*4 + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + + # 5.resize mask to latents shape we we concatenate the mask to the latents + mask = mask[:, 0, :, :] # batch_size, 8 * height, 8 * width (mask has not been 8x compressed) + mask = mask.view( + batch_size, height, self.vae_scale_factor, width, self.vae_scale_factor + ) # batch_size, height, 8, width, 8 + mask = mask.permute(0, 2, 4, 1, 3) # batch_size, 8, 8, height, width + mask = mask.reshape( + batch_size, self.vae_scale_factor * self.vae_scale_factor, height, width + ) # batch_size, 8*8, height, width + + # 6. pack the mask: + # batch_size, 64, height, width -> batch_size, height//2 * width//2 , 64*2*2 + mask = self._pack_latents( + mask, + batch_size, + self.vae_scale_factor * self.vae_scale_factor, + height, + width, + ) + mask = mask.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str] | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 512, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + image=None, + mask_image=None, + masked_image_latents=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + if image is not None and masked_image_latents is not None: + raise ValueError( + "Please provide either `image` or `masked_image_latents`, `masked_image_latents` should not be passed." + ) + + if image is not None and mask_image is None: + raise ValueError("Please provide `mask_image` when passing `image`.") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_image_ids + + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + image: torch.FloatTensor | None = None, + mask_image: torch.FloatTensor | None = None, + masked_image_latents: torch.FloatTensor | None = None, + height: int | None = None, + width: int | None = None, + strength: float = 1.0, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 30.0, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + mask_image_latent (`torch.Tensor`, `list[torch.Tensor]`): + `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask + latents tensor will be generated by `mask_image`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 30.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + strength, + height, + width, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + image=image, + mask_image=mask_image, + masked_image_latents=masked_image_latents, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + init_image = self.image_processor.preprocess(image, height=height, width=width) + init_image = init_image.to(dtype=torch.float32) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare prompt embeddings + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + timestep_device, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare mask and masked image latents + if masked_image_latents is not None: + masked_image_latents = masked_image_latents.to(latents.device) + else: + mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width) + + masked_image = init_image * (1 - mask_image) + masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype) + + height, width = init_image.shape[-2:] + mask, masked_image_latents = self.prepare_mask_latents( + mask_image, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 7. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=torch.cat((latents, masked_image_latents), dim=2), + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # 8. Post-process the image + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..cadff7736ff442ebdca41cebab5736eb06d6e8fc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -0,0 +1,1096 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + + >>> from diffusers import FluxImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> device = "cuda" + >>> pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe = pipe.to(device) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> init_image = load_image(url).resize((1024, 1024)) + + >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k" + + >>> images = pipe( + ... prompt=prompt, image=init_image, num_inference_steps=4, strength=0.95, guidance_scale=0.0 + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin): + r""" + The Flux pipeline for image inpainting. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str] | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 512, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + return image_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_ip_adapter_image in ip_adapter_image: + single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) + image_embeds.append(single_image_embeds[None, :]) + else: + if not isinstance(ip_adapter_image_embeds, list): + ip_adapter_image_embeds = [ip_adapter_image_embeds] + + if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_image_embeds in ip_adapter_image_embeds: + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for single_image_embeds in image_embeds: + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + prompt_2, + strength, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_image_ids + + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + negative_prompt: str | list[str] = None, + negative_prompt_2: str | list[str] | None = None, + true_cfg_scale: float = 1.0, + image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + strength: float = 0.6, + num_inference_steps: int = 28, + sigmas: list[float] | None = None, + guidance_scale: float = 7.0, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + negative_ip_adapter_image: PipelineImageInput | None = None, + negative_ip_adapter_image_embeds: list[torch.Tensor] | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_ip_adapter_image: + (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + negative_ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + strength, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Preprocess image + init_image = self.image_processor.preprocess(image, height=height, width=width) + init_image = init_image.to(dtype=torch.float32) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + if do_true_cfg: + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + _, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4.Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + timestep_device, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + + latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( + negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None + ): + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( + negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None + ): + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: + negative_image_embeds = self.prepare_ip_adapter_image_embeds( + negative_ip_adapter_image, + negative_ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..b8ce25a4f5a97ea9ed8ca8853c22a96c81465f53 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -0,0 +1,1211 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxInpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = FluxInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + >>> source = load_image(img_url) + >>> mask = load_image(mask_url) + >>> image = pipe(prompt=prompt, image=source, mask_image=mask).images[0] + >>> image.save("flux_inpainting.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterMixin): + r""" + The Flux pipeline for image inpainting. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels + ) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + vae_latent_channels=self.latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str] | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 512, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + return image_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_ip_adapter_image in ip_adapter_image: + single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) + image_embeds.append(single_image_embeds[None, :]) + else: + if not isinstance(ip_adapter_image_embeds, list): + ip_adapter_image_embeds = [ip_adapter_image_embeds] + + if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_image_embeds in ip_adapter_image_embeds: + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for single_image_embeds in image_embeds: + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + prompt_2, + image, + mask_image, + strength, + height, + width, + output_type, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + else: + noise = latents.to(device) + latents = noise + + noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, noise, image_latents, latent_image_ids + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 16: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = ( + masked_image_latents - self.vae.config.shift_factor + ) * self.vae.config.scaling_factor + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + mask = self._pack_latents( + mask.repeat(1, num_channels_latents, 1, 1), + batch_size, + num_channels_latents, + height, + width, + ) + + return mask, masked_image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + negative_prompt: str | list[str] = None, + negative_prompt_2: str | list[str] | None = None, + true_cfg_scale: float = 1.0, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + padding_mask_crop: int | None = None, + strength: float = 0.6, + num_inference_steps: int = 28, + sigmas: list[float] | None = None, + guidance_scale: float = 7.0, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + negative_ip_adapter_image: PipelineImageInput | None = None, + negative_ip_adapter_image_embeds: list[torch.Tensor] | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + mask_image_latent (`torch.Tensor`, `list[torch.Tensor]`): + `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask + latents tensor will be generated by `mask_image`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_ip_adapter_image: + (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + negative_ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + mask_image, + strength, + height, + width, + output_type=output_type, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + padding_mask_crop=padding_mask_crop, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + if do_true_cfg: + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + _, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4.Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + timestep_device, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + num_channels_transformer = self.transformer.config.in_channels + + latents, noise, image_latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + mask_condition = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( + negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None + ): + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( + negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None + ): + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: + negative_image_embeds = self.prepare_ip_adapter_image_embeds( + negative_ip_adapter_image, + negative_ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # for 64 channel transformer only. + init_latents_proper = image_latents + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + if padding_mask_crop is not None: + image = [ + self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image + ] + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_kontext.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_kontext.py new file mode 100644 index 0000000000000000000000000000000000000000..f4bbe42ef850b96f8d05f7121a2199f8cc5bf618 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_kontext.py @@ -0,0 +1,1159 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxKontextPipeline + >>> from diffusers.utils import load_image + + >>> pipe = FluxKontextPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png" + ... ).convert("RGB") + >>> prompt = "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors" + >>> image = pipe( + ... image=image, + ... prompt=prompt, + ... guidance_scale=2.5, + ... generator=torch.Generator().manual_seed(42), + ... ).images[0] + >>> image.save("output.png") + ``` +""" + +PREFERRED_KONTEXT_RESOLUTIONS = [ + (672, 1568), + (688, 1504), + (720, 1456), + (752, 1392), + (800, 1328), + (832, 1248), + (880, 1184), + (944, 1104), + (1024, 1024), + (1104, 944), + (1184, 880), + (1248, 832), + (1328, 800), + (1392, 752), + (1456, 720), + (1504, 688), + (1568, 672), +] + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class FluxKontextPipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, + FluxIPAdapterMixin, +): + r""" + The Flux Kontext pipeline for image-to-image and text-to-image generation. + + Reference: https://bfl.ai/announcements/flux-1-kontext-dev + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str] | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 512, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + return image_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_ip_adapter_image in ip_adapter_image: + single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) + image_embeds.append(single_image_embeds[None, :]) + else: + if not isinstance(ip_adapter_image_embeds, list): + ip_adapter_image_embeds = [ip_adapter_image_embeds] + + if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_image_embeds in ip_adapter_image_embeds: + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for single_image_embeds in image_embeds: + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + def prepare_latents( + self, + image: torch.Tensor | None, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = (batch_size, num_channels_latents, height, width) + + image_latents = image_ids = None + if image is not None: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latent_height, image_latent_width = image_latents.shape[2:] + image_latents = self._pack_latents( + image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width + ) + image_ids = self._prepare_latent_image_ids( + batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype + ) + # image ids are the same as latent ids with the first dimension set to 1 instead of 0 + image_ids[..., 0] = 1 + + latent_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents, image_latents, latent_ids, image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput | None = None, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + negative_prompt: str | list[str] = None, + negative_prompt_2: str | list[str] | None = None, + true_cfg_scale: float = 1.0, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 28, + sigmas: list[float] | None = None, + guidance_scale: float = 3.5, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + negative_ip_adapter_image: PipelineImageInput | None = None, + negative_ip_adapter_image_embeds: list[torch.Tensor] | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + max_area: int = 1024**2, + _auto_resize: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with prompt at the expense of lower image quality. + + Guidance-distilled models approximates true classifier-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_ip_adapter_image: + (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + negative_ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): + Maximum sequence length to use with the `prompt`. + max_area (`int`, defaults to `1024 ** 2`): + The maximum area of the generated image in pixels. The height and width will be adjusted to fit this + area while maintaining the aspect ratio. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_height, original_width = height, width + aspect_ratio = width / height + width = round((max_area * aspect_ratio) ** 0.5) + height = round((max_area / aspect_ratio) ** 0.5) + + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + if height != original_height or width != original_width: + logger.warning( + f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements." + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + if do_true_cfg: + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + negative_text_ids, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 3. Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + img = image[0] if isinstance(image, list) else image + image_height, image_width = self.image_processor.get_default_height_width(img) + aspect_ratio = image_width / image_height + if _auto_resize: + # Kontext is trained on specific resolutions, using one of them is recommended + _, image_width, image_height = min( + (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS + ) + image_width = image_width // multiple_of * multiple_of + image_height = image_height // multiple_of * multiple_of + image = self.image_processor.resize(image, image_height, image_width) + image = self.image_processor.preprocess(image, image_height, image_width) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, image_latents, latent_ids, image_ids = self.prepare_latents( + image, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + if image_ids is not None: + latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( + negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None + ): + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( + negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None + ): + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: + negative_image_embeds = self.prepare_ip_adapter_image_embeds( + negative_ip_adapter_image, + negative_ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + + # 6. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..313682dc7e335de860c2d0ecfc55f25c1ae63b6b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py @@ -0,0 +1,1485 @@ +# Copyright 2025 ZenAI. All rights reserved. +# author: @vuongminh1907 + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import FluxPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + # Inpainting with text only + ```py + >>> import torch + >>> from diffusers import FluxKontextInpaintPipeline + >>> from diffusers.utils import load_image + + >>> prompt = "Change the yellow dinosaur to green one" + >>> img_url = ( + ... "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_input.jpeg?raw=true" + ... ) + >>> mask_url = ( + ... "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_mask.png?raw=true" + ... ) + + >>> source = load_image(img_url) + >>> mask = load_image(mask_url) + + >>> pipe = FluxKontextInpaintPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> image = pipe(prompt=prompt, image=source, mask_image=mask, strength=1.0).images[0] + >>> image.save("kontext_inpainting_normal.png") + ``` + + # Inpainting with image conditioning + ```py + >>> import torch + >>> from diffusers import FluxKontextInpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = FluxKontextInpaintPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> prompt = "Replace this ball" + >>> img_url = "https://images.pexels.com/photos/39362/the-ball-stadion-football-the-pitch-39362.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500" + >>> mask_url = ( + ... "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/ball_mask.png?raw=true" + ... ) + >>> image_reference_url = ( + ... "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTah3x6OL_ECMBaZ5ZlJJhNsyC-OSMLWAI-xw&s" + ... ) + + >>> source = load_image(img_url) + >>> mask = load_image(mask_url) + >>> image_reference = load_image(image_reference_url) + + >>> mask = pipe.mask_processor.blur(mask, blur_factor=12) + >>> image = pipe( + ... prompt=prompt, image=source, mask_image=mask, image_reference=image_reference, strength=1.0 + ... ).images[0] + >>> image.save("kontext_inpainting_ref.png") + ``` +""" + +PREFERRED_KONTEXT_RESOLUTIONS = [ + (672, 1568), + (688, 1504), + (720, 1456), + (752, 1392), + (800, 1328), + (832, 1248), + (880, 1184), + (944, 1104), + (1024, 1024), + (1104, 944), + (1184, 880), + (1248, 832), + (1328, 800), + (1392, 752), + (1456, 720), + (1504, 688), + (1568, 672), +] + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class FluxKontextInpaintPipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, + FluxIPAdapterMixin, +): + r""" + The Flux Kontext pipeline for text-to-image generation. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + vae_latent_channels=self.latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str] | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 512, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + return image_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_ip_adapter_image in ip_adapter_image: + single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) + image_embeds.append(single_image_embeds[None, :]) + else: + if not isinstance(ip_adapter_image_embeds, list): + ip_adapter_image_embeds = [ip_adapter_image_embeds] + + if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_image_embeds in ip_adapter_image_embeds: + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for single_image_embeds in image_embeds: + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + image, + mask_image, + strength, + height, + width, + output_type, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + def prepare_latents( + self, + image: torch.Tensor | None, + timestep: int, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + image_reference: torch.Tensor | None = None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = (batch_size, num_channels_latents, height, width) + + # Prepare image latents + image_latents = image_ids = None + if image is not None: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + # Prepare image reference latents + image_reference_latents = image_reference_ids = None + if image_reference is not None: + image_reference = image_reference.to(device=device, dtype=dtype) + if image_reference.shape[1] != self.latent_channels: + image_reference_latents = self._encode_vae_image(image=image_reference, generator=generator) + else: + image_reference_latents = image_reference + if batch_size > image_reference_latents.shape[0] and batch_size % image_reference_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_reference_latents.shape[0] + image_reference_latents = torch.cat([image_reference_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_reference_latents.shape[0] and batch_size % image_reference_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image_reference` of batch size {image_reference_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_reference_latents = torch.cat([image_reference_latents], dim=0) + + latent_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + else: + noise = latents.to(device=device, dtype=dtype) + latents = noise + + image_latent_height, image_latent_width = image_latents.shape[2:] + image_latents = self._pack_latents( + image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width + ) + image_ids = self._prepare_latent_image_ids( + batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype + ) + # image ids are the same as latent ids with the first dimension set to 1 instead of 0 + image_ids[..., 0] = 1 + + if image_reference_latents is not None: + image_reference_latent_height, image_reference_latent_width = image_reference_latents.shape[2:] + image_reference_latents = self._pack_latents( + image_reference_latents, + batch_size, + num_channels_latents, + image_reference_latent_height, + image_reference_latent_width, + ) + image_reference_ids = self._prepare_latent_image_ids( + batch_size, image_reference_latent_height // 2, image_reference_latent_width // 2, device, dtype + ) + # image_reference_ids are the same as latent ids with the first dimension set to 1 instead of 0 + image_reference_ids[..., 0] = 1 + + noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + return latents, image_latents, image_reference_latents, latent_ids, image_ids, image_reference_ids, noise + + # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 16: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = ( + masked_image_latents - self.vae.config.shift_factor + ) * self.vae.config.scaling_factor + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + mask = self._pack_latents( + mask.repeat(1, num_channels_latents, 1, 1), + batch_size, + num_channels_latents, + height, + width, + ) + + return mask, masked_image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput | None = None, + image_reference: PipelineImageInput | None = None, + mask_image: PipelineImageInput = None, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + negative_prompt: str | list[str] = None, + negative_prompt_2: str | list[str] | None = None, + true_cfg_scale: float = 1.0, + height: int | None = None, + width: int | None = None, + strength: float = 1.0, + padding_mask_crop: int | None = None, + num_inference_steps: int = 28, + sigmas: list[float] | None = None, + guidance_scale: float = 3.5, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + negative_ip_adapter_image: PipelineImageInput | None = None, + negative_ip_adapter_image_embeds: list[torch.Tensor] | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + max_area: int = 1024**2, + _auto_resize: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be be inpainted (which parts of the image + to be masked out with `mask_image` and repainted according to `prompt` and `image_reference`). For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + image_reference (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point for the + masked area. For both numpy array and pytorch tensor, the expected value range is between `[0, 1]` If + it's a tensor or a list or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)` If it is + a numpy array or a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can + also accept image latents as `image`, but if passing latents directly it is not encoded again. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and + `negative_prompt` is provided. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximates true classifier-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_ip_adapter_image: + (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + negative_ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): + Maximum sequence length to use with the `prompt`. + max_area (`int`, defaults to `1024 ** 2`): + The maximum area of the generated image in pixels. The height and width will be adjusted to fit this + area while maintaining the aspect ratio. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_height, original_width = height, width + aspect_ratio = width / height + width = round((max_area * aspect_ratio) ** 0.5) + height = round((max_area / aspect_ratio) ** 0.5) + + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + if height != original_height or width != original_width: + logger.warning( + f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements." + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + mask_image, + strength, + height, + width, + output_type=output_type, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + padding_mask_crop=padding_mask_crop, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4: + image = torch.cat(image, dim=0) + img = image[0] if isinstance(image, list) else image + image_height, image_width = self.image_processor.get_default_height_width(img) + aspect_ratio = image_width / image_height + if _auto_resize: + # Kontext is trained on specific resolutions, using one of them is recommended + _, image_width, image_height = min( + (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS + ) + image_width = image_width // multiple_of * multiple_of + image_height = image_height // multiple_of * multiple_of + image = self.image_processor.resize(image, image_height, image_width) + + # Choose the resolution of the image to be the same as the image + width = image_width + height = image_height + + # 2.1 Preprocess mask + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + image = self.image_processor.preprocess( + image, image_height, image_width, crops_coords=crops_coords, resize_mode=resize_mode + ) + else: + raise ValueError("image must be provided correctly for inpainting") + + init_image = image.to(dtype=torch.float32) + + # 2.1 Preprocess image_reference + if image_reference is not None and not ( + isinstance(image_reference, torch.Tensor) and image_reference.size(1) == self.latent_channels + ): + if ( + isinstance(image_reference, list) + and isinstance(image_reference[0], torch.Tensor) + and image_reference[0].ndim == 4 + ): + image_reference = torch.cat(image_reference, dim=0) + img_reference = image_reference[0] if isinstance(image_reference, list) else image_reference + image_reference_height, image_reference_width = self.image_processor.get_default_height_width( + img_reference + ) + aspect_ratio = image_reference_width / image_reference_height + if _auto_resize: + # Kontext is trained on specific resolutions, using one of them is recommended + _, image_reference_width, image_reference_height = min( + (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS + ) + image_reference_width = image_reference_width // multiple_of * multiple_of + image_reference_height = image_reference_height // multiple_of * multiple_of + image_reference = self.image_processor.resize( + image_reference, image_reference_height, image_reference_width + ) + image_reference = self.image_processor.preprocess( + image_reference, + image_reference_height, + image_reference_width, + crops_coords=crops_coords, + resize_mode=resize_mode, + ) + else: + image_reference = None + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + if do_true_cfg: + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + negative_text_ids, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, image_latents, image_reference_latents, latent_ids, image_ids, image_reference_ids, noise = ( + self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image_reference, + ) + ) + + if image_reference_ids is not None: + latent_ids = torch.cat([latent_ids, image_reference_ids], dim=0) # dim 0 is sequence dimension + elif image_ids is not None: + latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension + + mask_condition = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + masked_image = init_image * (mask_condition < 0.5) + + mask, _ = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( + negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None + ): + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( + negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None + ): + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: + negative_image_embeds = self.prepare_ip_adapter_image_embeds( + negative_ip_adapter_image, + negative_ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + + latent_model_input = latents + if image_reference_latents is not None: + latent_model_input = torch.cat([latents, image_reference_latents], dim=1) + elif image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + init_latents_proper = image_latents + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_prior_redux.py new file mode 100644 index 0000000000000000000000000000000000000000..330e2623b287dec2a6e157d065ecc7ae12111937 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -0,0 +1,490 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from PIL import Image +from transformers import ( + CLIPTextModel, + CLIPTokenizer, + SiglipImageProcessor, + SiglipVisionModel, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput +from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ..pipeline_utils import DiffusionPipeline +from .modeling_flux import ReduxImageEncoder +from .pipeline_output import FluxPriorReduxPipelineOutput + + +if is_torch_xla_available(): + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxPriorReduxPipeline, FluxPipeline + >>> from diffusers.utils import load_image + + >>> device = "cuda" + >>> dtype = torch.bfloat16 + + >>> repo_redux = "black-forest-labs/FLUX.1-Redux-dev" + >>> repo_base = "black-forest-labs/FLUX.1-dev" + >>> pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(repo_redux, torch_dtype=dtype).to(device) + >>> pipe = FluxPipeline.from_pretrained( + ... repo_base, text_encoder=None, text_encoder_2=None, torch_dtype=torch.bfloat16 + ... ).to(device) + + >>> image = load_image( + ... "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png" + ... ) + >>> pipe_prior_output = pipe_prior_redux(image) + >>> images = pipe( + ... guidance_scale=2.5, + ... num_inference_steps=50, + ... generator=torch.Generator("cpu").manual_seed(0), + ... **pipe_prior_output, + ... ).images + >>> images[0].save("flux-redux.png") + ``` +""" + + +class FluxPriorReduxPipeline(DiffusionPipeline): + r""" + The Flux Redux pipeline for image-to-image generation. + + Reference: https://blackforestlabs.ai/flux-1-tools/ + + Args: + image_encoder ([`SiglipVisionModel`]): + SIGLIP vision model to encode the input image. + feature_extractor ([`SiglipImageProcessor`]): + Image processor for preprocessing images for the SIGLIP model. + image_embedder ([`ReduxImageEncoder`]): + Redux image encoder to process the SIGLIP embeddings. + text_encoder ([`CLIPTextModel`], *optional*): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`], *optional*): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`, *optional*): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`, *optional*): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "image_encoder->image_embedder" + _optional_components = [ + "text_encoder", + "tokenizer", + "text_encoder_2", + "tokenizer_2", + ] + _callback_tensor_inputs = [] + + def __init__( + self, + image_encoder: SiglipVisionModel, + feature_extractor: SiglipImageProcessor, + image_embedder: ReduxImageEncoder, + text_encoder: CLIPTextModel = None, + tokenizer: CLIPTokenizer = None, + text_encoder_2: T5EncoderModel = None, + tokenizer_2: T5TokenizerFast = None, + ): + super().__init__() + + self.register_modules( + image_encoder=image_encoder, + feature_extractor=feature_extractor, + image_embedder=image_embedder, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + + def check_inputs( + self, + image, + prompt, + prompt_2, + prompt_embeds=None, + pooled_prompt_embeds=None, + prompt_embeds_scale=1.0, + pooled_prompt_embeds_scale=1.0, + ): + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + if prompt is not None and (isinstance(prompt, list) and isinstance(image, list) and len(prompt) != len(image)): + raise ValueError( + f"number of prompts must be equal to number of images, but {len(prompt)} prompts were provided and {len(image)} images" + ) + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + if isinstance(prompt_embeds_scale, list) and ( + isinstance(image, list) and len(prompt_embeds_scale) != len(image) + ): + raise ValueError( + f"number of weights must be equal to number of images, but {len(prompt_embeds_scale)} weights were provided and {len(image)} images" + ) + + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + image = self.feature_extractor.preprocess( + images=image, do_resize=True, return_tensors="pt", do_convert_rgb=True + ) + image = image.to(device=device, dtype=dtype) + + image_enc_hidden_states = self.image_encoder(**image).last_hidden_state + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + + return image_enc_hidden_states + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str] | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 512, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + prompt_embeds_scale: float | list[float] | None = 1.0, + pooled_prompt_embeds_scale: float | list[float] | None = 1.0, + return_dict: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. **experimental feature**: to use this feature, + make sure to explicitly load text encoders to the pipeline. Prompts will be ignored if text encoders + are not loaded. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPriorReduxPipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.flux.FluxPriorReduxPipelineOutput`] or `tuple`: + [`~pipelines.flux.FluxPriorReduxPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + image, + prompt, + prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + prompt_embeds_scale=prompt_embeds_scale, + pooled_prompt_embeds_scale=pooled_prompt_embeds_scale, + ) + + # 2. Define call parameters + if image is not None and isinstance(image, Image.Image): + batch_size = 1 + elif image is not None and isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + if prompt is not None and isinstance(prompt, str): + prompt = batch_size * [prompt] + if isinstance(prompt_embeds_scale, float): + prompt_embeds_scale = batch_size * [prompt_embeds_scale] + if isinstance(pooled_prompt_embeds_scale, float): + pooled_prompt_embeds_scale = batch_size * [pooled_prompt_embeds_scale] + + device = self._execution_device + + # 3. Prepare image embeddings + image_latents = self.encode_image(image, device, 1) + + image_embeds = self.image_embedder(image_latents).image_embeds + image_embeds = image_embeds.to(device=device) + + # 3. Prepare (dummy) text embeddings + if hasattr(self, "text_encoder") and self.text_encoder is not None: + ( + prompt_embeds, + pooled_prompt_embeds, + _, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=1, + max_sequence_length=512, + lora_scale=None, + ) + else: + if prompt is not None: + logger.warning( + "prompt input is ignored when text encoders are not loaded to the pipeline. " + "Make sure to explicitly load the text encoders to enable prompt input. " + ) + # max_sequence_length is 512, t5 encoder hidden size is 4096 + prompt_embeds = torch.zeros((batch_size, 512, 4096), device=device, dtype=image_embeds.dtype) + # pooled_prompt_embeds is 768, clip text encoder hidden size + pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype) + + # scale & concatenate image and text embeddings + prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1) + + prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None] + pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[ + :, None + ] + + # weighted sum + prompt_embeds = torch.sum(prompt_embeds, dim=0, keepdim=True) + pooled_prompt_embeds = torch.sum(pooled_prompt_embeds, dim=0, keepdim=True) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (prompt_embeds, pooled_prompt_embeds) + + return FluxPriorReduxPipelineOutput(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..7f24bdbf5d60da1f0bd44971d10616ddc3dc655e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_output.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image +import torch + +from ...utils import BaseOutput + + +@dataclass +class FluxPipelineOutput(BaseOutput): + """ + Output class for Flux image generation pipelines. + + Args: + images (`list[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size, + height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion + pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be + passed to the decoder. + """ + + images: list[PIL.Image.Image] | np.ndarray + + +@dataclass +class FluxPriorReduxPipelineOutput(BaseOutput): + """ + Output class for Flux Prior Redux pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + prompt_embeds: torch.Tensor + pooled_prompt_embeds: torch.Tensor diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f6e1d520663038b617400375b32d38e9371f7268 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/__init__.py @@ -0,0 +1,49 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["Flux2PipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_flux2"] = ["Flux2Pipeline"] + _import_structure["pipeline_flux2_klein"] = ["Flux2KleinPipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_flux2 import Flux2Pipeline + from .pipeline_flux2_klein import Flux2KleinPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81d2156b0e6c89a44b177a0f8c5ffcdf25e28921 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/__pycache__/image_processor.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/__pycache__/image_processor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee5f0fab24520392f256f1694b25c9f32b3a8dc7 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/__pycache__/image_processor.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/__pycache__/pipeline_flux2.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/__pycache__/pipeline_flux2.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd5516268346f623b9e124c2d9a7b35571d69988 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/__pycache__/pipeline_flux2.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/__pycache__/pipeline_flux2_klein.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/__pycache__/pipeline_flux2_klein.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba7dc26ac98b7ef63c95b9bb71eb815cb863c08d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/__pycache__/pipeline_flux2_klein.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/__pycache__/pipeline_output.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/__pycache__/pipeline_output.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..875136cdc0b467849b1907ae147b1d2a2d80a76c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/__pycache__/pipeline_output.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/__pycache__/system_messages.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/__pycache__/system_messages.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..194e37d907bd19b0755efe55f25842c2400c1583 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/__pycache__/system_messages.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/image_processor.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/image_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..e0a1b80ce5336b3e99d86d537fae29038542a804 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/image_processor.py @@ -0,0 +1,177 @@ +# Copyright 2025 The Black Forest Labs Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import PIL.Image + +from ...configuration_utils import register_to_config +from ...image_processor import VaeImageProcessor + + +class Flux2ImageProcessor(VaeImageProcessor): + r""" + Image processor to preprocess the reference (character) image for the Flux2 model. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept + `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. + vae_scale_factor (`int`, *optional*, defaults to `16`): + VAE (spatial) scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of + this factor. + vae_latent_channels (`int`, *optional*, defaults to `32`): + VAE latent channels. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image to [-1,1]. + do_convert_rgb (`bool`, *optional*, defaults to be `True`): + Whether to convert the images to RGB format. + """ + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 16, + vae_latent_channels: int = 32, + do_normalize: bool = True, + do_convert_rgb: bool = True, + ): + super().__init__( + do_resize=do_resize, + vae_scale_factor=vae_scale_factor, + vae_latent_channels=vae_latent_channels, + do_normalize=do_normalize, + do_convert_rgb=do_convert_rgb, + ) + + @staticmethod + def check_image_input( + image: PIL.Image.Image, max_aspect_ratio: int = 8, min_side_length: int = 64, max_area: int = 1024 * 1024 + ) -> PIL.Image.Image: + """ + Check if image meets minimum size and aspect ratio requirements. + + Args: + image: PIL Image to validate + max_aspect_ratio: Maximum allowed aspect ratio (width/height or height/width) + min_side_length: Minimum pixels required for width and height + max_area: Maximum allowed area in pixels² + + Returns: + The input image if valid + + Raises: + ValueError: If image is too small or aspect ratio is too extreme + """ + if not isinstance(image, PIL.Image.Image): + raise ValueError(f"Image must be a PIL.Image.Image, got {type(image)}") + + width, height = image.size + + # Check minimum dimensions + if width < min_side_length or height < min_side_length: + raise ValueError( + f"Image too small: {width}×{height}. Both dimensions must be at least {min_side_length}px" + ) + + # Check aspect ratio + aspect_ratio = max(width / height, height / width) + if aspect_ratio > max_aspect_ratio: + raise ValueError( + f"Aspect ratio too extreme: {width}×{height} (ratio: {aspect_ratio:.1f}:1). " + f"Maximum allowed ratio is {max_aspect_ratio}:1" + ) + + return image + + @staticmethod + def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> PIL.Image.Image: + image_width, image_height = image.size + + scale = math.sqrt(target_area / (image_width * image_height)) + width = int(image_width * scale) + height = int(image_height * scale) + + return image.resize((width, height), PIL.Image.Resampling.LANCZOS) + + @staticmethod + def _resize_if_exceeds_area(image, target_area=1024 * 1024) -> PIL.Image.Image: + image_width, image_height = image.size + pixel_count = image_width * image_height + if pixel_count <= target_area: + return image + return Flux2ImageProcessor._resize_to_target_area(image, target_area) + + def _resize_and_crop( + self, + image: PIL.Image.Image, + width: int, + height: int, + ) -> PIL.Image.Image: + r""" + center crop the image to the specified width and height. + + Args: + image (`PIL.Image.Image`): + The image to resize and crop. + width (`int`): + The width to resize the image to. + height (`int`): + The height to resize the image to. + + Returns: + `PIL.Image.Image`: + The resized and cropped image. + """ + image_width, image_height = image.size + + left = (image_width - width) // 2 + top = (image_height - height) // 2 + right = left + width + bottom = top + height + + return image.crop((left, top, right, bottom)) + + # Taken from + # https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/sampling.py#L310C1-L339C19 + @staticmethod + def concatenate_images(images: list[PIL.Image.Image]) -> PIL.Image.Image: + """ + Concatenate a list of PIL images horizontally with center alignment and white background. + """ + + # If only one image, return a copy of it + if len(images) == 1: + return images[0].copy() + + # Convert all images to RGB if not already + images = [img.convert("RGB") if img.mode != "RGB" else img for img in images] + + # Calculate dimensions for horizontal concatenation + total_width = sum(img.width for img in images) + max_height = max(img.height for img in images) + + # Create new image with white background + background_color = (255, 255, 255) + new_img = PIL.Image.new("RGB", (total_width, max_height), background_color) + + # Paste images with center alignment + x_offset = 0 + for img in images: + y_offset = (max_height - img.height) // 2 + new_img.paste(img, (x_offset, y_offset)) + x_offset += img.width + + return new_img diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/pipeline_flux2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/pipeline_flux2.py new file mode 100644 index 0000000000000000000000000000000000000000..6cd0563fcc1905e746d9e945db5b2ab674c07e7a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/pipeline_flux2.py @@ -0,0 +1,1032 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL +import torch +from transformers import AutoProcessor, Mistral3ForConditionalGeneration + +from ...loaders import Flux2LoraLoaderMixin +from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .image_processor import Flux2ImageProcessor +from .pipeline_output import Flux2PipelineOutput +from .system_messages import SYSTEM_MESSAGE, SYSTEM_MESSAGE_UPSAMPLING_I2I, SYSTEM_MESSAGE_UPSAMPLING_T2I + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import Flux2Pipeline + + >>> pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0] + >>> image.save("flux.png") + ``` +""" + +UPSAMPLING_MAX_IMAGE_SIZE = 768**2 + + +# Adapted from +# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L68 +def format_input( + prompts: list[str], + system_message: str = SYSTEM_MESSAGE, + images: list[PIL.Image.Image, list[list[PIL.Image.Image]]] | None = None, +): + """ + Format a batch of text prompts into the conversation format expected by apply_chat_template. Optionally, add images + to the input. + + Args: + prompts: List of text prompts + system_message: System message to use (default: CREATIVE_SYSTEM_MESSAGE) + images (optional): List of images to add to the input. + + Returns: + List of conversations, where each conversation is a list of message dicts + """ + # Remove [IMG] tokens from prompts to avoid Pixtral validation issues + # when truncation is enabled. The processor counts [IMG] tokens and fails + # if the count changes after truncation. + cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts] + + if images is None or len(images) == 0: + return [ + [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + for prompt in cleaned_txt + ] + else: + assert len(images) == len(prompts), "Number of images must match number of prompts" + messages = [ + [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + ] + for _ in cleaned_txt + ] + + for i, (el, images) in enumerate(zip(messages, images)): + # optionally add the images per batch element. + if images is not None: + el.append( + { + "role": "user", + "content": [{"type": "image", "image": image_obj} for image_obj in images], + } + ) + # add the text. + el.append( + { + "role": "user", + "content": [{"type": "text", "text": cleaned_txt[i]}], + } + ) + + return messages + + +# Adapted from +# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L49C5-L66C19 +def _validate_and_process_images( + images: list[list[PIL.Image.Image]] | list[PIL.Image.Image], + image_processor: Flux2ImageProcessor, + upsampling_max_image_size: int, +) -> list[list[PIL.Image.Image]]: + # Simple validation: ensure it's a list of PIL images or list of lists of PIL images + if not images: + return [] + + # Check if it's a list of lists or a list of images + if isinstance(images[0], PIL.Image.Image): + # It's a list of images, convert to list of lists + images = [[im] for im in images] + + # potentially concatenate multiple images to reduce the size + images = [[image_processor.concatenate_images(img_i)] if len(img_i) > 1 else img_i for img_i in images] + + # cap the pixels + images = [ + [image_processor._resize_if_exceeds_area(img_i, upsampling_max_image_size) for img_i in img_i] + for img_i in images + ] + return images + + +# Taken from +# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/sampling.py#L251 +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin): + r""" + The Flux2 pipeline for text-to-image generation. + + Reference: [https://bfl.ai/blog/flux-2](https://bfl.ai/blog/flux-2) + + Args: + transformer ([`Flux2Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLFlux2`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Mistral3ForConditionalGeneration`]): + [Mistral3ForConditionalGeneration](https://huggingface.co/docs/transformers/en/model_doc/mistral3#transformers.Mistral3ForConditionalGeneration) + tokenizer (`AutoProcessor`): + Tokenizer of class + [PixtralProcessor](https://huggingface.co/docs/transformers/en/model_doc/pixtral#transformers.PixtralProcessor). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLFlux2, + text_encoder: Mistral3ForConditionalGeneration, + tokenizer: AutoProcessor, + transformer: Flux2Transformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 512 + self.default_sample_size = 128 + + self.system_message = SYSTEM_MESSAGE + self.system_message_upsampling_t2i = SYSTEM_MESSAGE_UPSAMPLING_T2I + self.system_message_upsampling_i2i = SYSTEM_MESSAGE_UPSAMPLING_I2I + self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE + + @staticmethod + def _get_mistral_3_small_prompt_embeds( + text_encoder: Mistral3ForConditionalGeneration, + tokenizer: AutoProcessor, + prompt: str | list[str], + dtype: torch.dtype | None = None, + device: torch.device | None = None, + max_sequence_length: int = 512, + system_message: str = SYSTEM_MESSAGE, + hidden_states_layers: list[int] = (10, 20, 30), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + # Format input messages + messages_batch = format_input(prompts=prompt, system_message=system_message) + + # Process all messages at once + inputs = tokenizer.apply_chat_template( + messages_batch, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + # Move to device + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @staticmethod + def _prepare_text_ids( + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: torch.Tensor | None = None, + ): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, l) + out_ids.append(coords) + + return torch.stack(out_ids) + + @staticmethod + def _prepare_latent_ids( + latents: torch.Tensor, # (B, C, H, W) + ): + r""" + Generates 4D position coordinates (T, H, W, L) for latent tensors. + + Args: + latents (torch.Tensor): + Latent tensor of shape (B, C, H, W) + + Returns: + torch.Tensor: + Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0, + H=[0..H-1], W=[0..W-1], L=0 + """ + + batch_size, _, height, width = latents.shape + + t = torch.arange(1) # [0] - time dimension + h = torch.arange(height) + w = torch.arange(width) + l = torch.arange(1) # [0] - layer dimension + + # Create position IDs: (H*W, 4) + latent_ids = torch.cartesian_prod(t, h, w, l) + + # Expand to batch: (B, H*W, 4) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + + @staticmethod + def _prepare_image_ids( + image_latents: list[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] + scale: int = 10, + ): + r""" + Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents. + + This function creates a unique coordinate for every pixel/patch across all input latent with different + dimensions. + + Args: + image_latents (list[torch.Tensor]): + A list of image latent feature tensors, typically of shape (C, H, W). + scale (int, optional): + A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th + latent is: 'scale + scale * i'. Defaults to 10. + + Returns: + torch.Tensor: + The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all + input latents. + + Coordinate Components (Dimension 4): + - T (Time): The unique index indicating which latent image the coordinate belongs to. + - H (Height): The row index within that latent image. + - W (Width): The column index within that latent image. + - L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1) + """ + + if not isinstance(image_latents, list): + raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.") + + # create time offset for each reference image + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + x = x.squeeze(0) + _, height, width = x.shape + + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + + @staticmethod + def _patchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + @staticmethod + def _unpatchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) + return latents + + @staticmethod + def _pack_latents(latents): + """ + pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels) + """ + + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + + return latents + + @staticmethod + def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]: + """ + using position ids to scatter tokens into place + """ + x_list = [] + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids = h_ids * w + w_ids + + out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W) + + out = out.view(h, w, ch).permute(2, 0, 1) + x_list.append(out) + + return torch.stack(x_list, dim=0) + + def upsample_prompt( + self, + prompt: str | list[str], + images: list[PIL.Image.Image, list[list[PIL.Image.Image]]] = None, + temperature: float = 0.15, + device: torch.device = None, + ) -> list[str]: + prompt = [prompt] if isinstance(prompt, str) else prompt + device = self.text_encoder.device if device is None else device + + # Set system message based on whether images are provided + if images is None or len(images) == 0 or images[0] is None: + system_message = SYSTEM_MESSAGE_UPSAMPLING_T2I + else: + system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I + + # Validate and process the input images + if images: + images = _validate_and_process_images(images, self.image_processor, self.upsampling_max_image_size) + + # Format input messages + messages_batch = format_input(prompts=prompt, system_message=system_message, images=images) + + # Process all messages at once + # with image processing a too short max length can throw an error in here. + inputs = self.tokenizer.apply_chat_template( + messages_batch, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=2048, + ) + + # Move to device + inputs["input_ids"] = inputs["input_ids"].to(device) + inputs["attention_mask"] = inputs["attention_mask"].to(device) + + if "pixel_values" in inputs: + inputs["pixel_values"] = inputs["pixel_values"].to(device, self.text_encoder.dtype) + + # Generate text using the model's generate method + generated_ids = self.text_encoder.generate( + **inputs, + max_new_tokens=512, + do_sample=True, + temperature=temperature, + use_cache=True, + ) + + # Decode only the newly generated tokens (skip input tokens) + # Extract only the generated portion + input_length = inputs["input_ids"].shape[1] + generated_tokens = generated_ids[:, input_length:] + + upsampled_prompt = self.tokenizer.tokenizer.batch_decode( + generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + return upsampled_prompt + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 512, + text_encoder_out_layers: tuple[int] = (10, 20, 30), + ): + device = device or self._execution_device + + if prompt is None: + prompt = "" + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self._get_mistral_3_small_prompt_embeds( + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length, + system_message=self.system_message, + hidden_states_layers=text_encoder_out_layers, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = self._prepare_text_ids(prompt_embeds) + text_ids = text_ids.to(device) + return prompt_embeds, text_ids + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if image.ndim != 4: + raise ValueError(f"Expected image dims 4, got {image.ndim}.") + + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = self._patchify_latents(image_latents) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps) + image_latents = (image_latents - latents_bn_mean) / latents_bn_std + + return image_latents + + def prepare_latents( + self, + batch_size, + num_latents_channels, + height, + width, + dtype, + device, + generator: torch.Generator, + latents: torch.Tensor | None = None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_latents_channels * 4, height // 2, width // 2) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + latent_ids = self._prepare_latent_ids(latents) + latent_ids = latent_ids.to(device) + + latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C] + return latents, latent_ids + + def prepare_image_latents( + self, + images: list[torch.Tensor], + batch_size, + generator: torch.Generator, + device, + dtype, + ): + image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + imagge_latent = self._encode_vae_image(image=image, generator=generator) + image_latents.append(imagge_latent) # (1, 128, 32, 32) + + image_latent_ids = self._prepare_image_ids(image_latents) + + # Pack each latent and concatenate + packed_latents = [] + for latent in image_latents: + # latent: (1, 128, 32, 32) + packed = self._pack_latents(latent) # (1, 1024, 128) + packed = packed.squeeze(0) # (1024, 128) - remove batch dim + packed_latents.append(packed) + + # Concatenate all reference tokens along sequence dimension + image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128) + image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128) + + image_latents = image_latents.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.to(device) + + return image_latents, image_latent_ids + + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if ( + height is not None + and height % (self.vae_scale_factor * 2) != 0 + or width is not None + and width % (self.vae_scale_factor * 2) != 0 + ): + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: list[PIL.Image.Image, PIL.Image.Image] | None = None, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float | None = 4.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + text_encoder_out_layers: tuple[int] = (10, 20, 30), + caption_upsample_temperature: float = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + guidance_scale (`float`, *optional*, defaults to 1.0): + Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + text_encoder_out_layers (`tuple[int]`): + Layer indices to use in the `text_encoder` to derive the final prompt embeddings. + caption_upsample_temperature (`float`): + When specified, we will try to perform caption upsampling for potentially improved outputs. We + recommend setting it to 0.15 if caption upsampling is to be performed. + + Examples: + + Returns: + [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. prepare text embeddings + if caption_upsample_temperature: + prompt = self.upsample_prompt( + prompt, images=image, temperature=caption_upsample_temperature, device=device + ) + prompt_embeds, text_ids = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + text_encoder_out_layers=text_encoder_out_layers, + ) + + # 4. process images + if image is not None and not isinstance(image, list): + image = [image] + + condition_images = None + if image is not None: + for img in image: + self.image_processor.check_image_input(img) + + condition_images = [] + for img in image: + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + img = self.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + + multiple_of = self.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") + condition_images.append(img) + height = height or image_height + width = width or image_width + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 5. prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_ids = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_latents_channels=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + image_latents = None + image_latent_ids = None + if condition_images is not None: + image_latents, image_latent_ids = self.prepare_image_latents( + images=condition_images, + batch_size=batch_size * num_images_per_prompt, + generator=generator, + device=device, + dtype=self.vae.dtype, + ) + + # 6. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + image_seq_len = latents.shape[1] + mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + + # 7. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + latent_model_input = latents.to(self.transformer.dtype) + latent_image_ids = latent_ids + + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) + latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) + + noise_pred = self.transformer( + hidden_states=latent_model_input, # (B, image_seq_len, C) + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, # B, text_seq_len, 4 + img_ids=latent_image_ids, # B, image_seq_len, 4 + joint_attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1) :] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents_with_ids(latents, latent_ids) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + latents = latents * latents_bn_std + latents_bn_mean + latents = self._unpatchify_latents(latents) + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return Flux2PipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/pipeline_flux2_klein.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/pipeline_flux2_klein.py new file mode 100644 index 0000000000000000000000000000000000000000..936d2c3804ab840b13db9b5b56415cbef728ee0a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/pipeline_flux2_klein.py @@ -0,0 +1,918 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL +import torch +from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM + +from ...loaders import Flux2LoraLoaderMixin +from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .image_processor import Flux2ImageProcessor +from .pipeline_output import Flux2PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import Flux2KleinPipeline + + >>> pipe = Flux2KleinPipeline.from_pretrained( + ... "black-forest-labs/FLUX.2-klein-base-9B", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=4.0).images[0] + >>> image.save("flux2_output.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux2.pipeline_flux2.compute_empirical_mu +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class Flux2KleinPipeline(DiffusionPipeline, Flux2LoraLoaderMixin): + r""" + The Flux2 Klein pipeline for text-to-image generation. + + Reference: + [https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence](https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence) + + Args: + transformer ([`Flux2Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLFlux2`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen3ForCausalLM`]): + [Qwen3ForCausalLM](https://huggingface.co/docs/transformers/en/model_doc/qwen3#transformers.Qwen3ForCausalLM) + tokenizer (`Qwen2TokenizerFast`): + Tokenizer of class + [Qwen2TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLFlux2, + text_encoder: Qwen3ForCausalLM, + tokenizer: Qwen2TokenizerFast, + transformer: Flux2Transformer2DModel, + is_distilled: bool = False, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + ) + + self.register_to_config(is_distilled=is_distilled) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 512 + self.default_sample_size = 128 + + @staticmethod + def _get_qwen3_prompt_embeds( + text_encoder: Qwen3ForCausalLM, + tokenizer: Qwen2TokenizerFast, + prompt: str | list[str], + dtype: torch.dtype | None = None, + device: torch.device | None = None, + max_sequence_length: int = 512, + hidden_states_layers: list[int] = (9, 18, 27), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + all_input_ids = [] + all_attention_masks = [] + + for single_prompt in prompt: + messages = [{"role": "user", "content": single_prompt}] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + all_input_ids.append(inputs["input_ids"]) + all_attention_masks.append(inputs["attention_mask"]) + + input_ids = torch.cat(all_input_ids, dim=0).to(device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_text_ids + def _prepare_text_ids( + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: torch.Tensor | None = None, + ): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, l) + out_ids.append(coords) + + return torch.stack(out_ids) + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids + def _prepare_latent_ids( + latents: torch.Tensor, # (B, C, H, W) + ): + r""" + Generates 4D position coordinates (T, H, W, L) for latent tensors. + + Args: + latents (torch.Tensor): + Latent tensor of shape (B, C, H, W) + + Returns: + torch.Tensor: + Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0, + H=[0..H-1], W=[0..W-1], L=0 + """ + + batch_size, _, height, width = latents.shape + + t = torch.arange(1) # [0] - time dimension + h = torch.arange(height) + w = torch.arange(width) + l = torch.arange(1) # [0] - layer dimension + + # Create position IDs: (H*W, 4) + latent_ids = torch.cartesian_prod(t, h, w, l) + + # Expand to batch: (B, H*W, 4) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids + def _prepare_image_ids( + image_latents: list[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] + scale: int = 10, + ): + r""" + Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents. + + This function creates a unique coordinate for every pixel/patch across all input latent with different + dimensions. + + Args: + image_latents (list[torch.Tensor]): + A list of image latent feature tensors, typically of shape (C, H, W). + scale (int, optional): + A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th + latent is: 'scale + scale * i'. Defaults to 10. + + Returns: + torch.Tensor: + The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all + input latents. + + Coordinate Components (Dimension 4): + - T (Time): The unique index indicating which latent image the coordinate belongs to. + - H (Height): The row index within that latent image. + - W (Width): The column index within that latent image. + - L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1) + """ + + if not isinstance(image_latents, list): + raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.") + + # create time offset for each reference image + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + x = x.squeeze(0) + _, height, width = x.shape + + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents + def _patchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpatchify_latents + def _unpatchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._pack_latents + def _pack_latents(latents): + """ + pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels) + """ + + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids + def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]: + """ + using position ids to scatter tokens into place + """ + x_list = [] + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids = h_ids * w + w_ids + + out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W) + + out = out.view(h, w, ch).permute(2, 0, 1) + x_list.append(out) + + return torch.stack(x_list, dim=0) + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 512, + text_encoder_out_layers: tuple[int] = (9, 18, 27), + ): + device = device or self._execution_device + + if prompt is None: + prompt = "" + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self._get_qwen3_prompt_embeds( + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length, + hidden_states_layers=text_encoder_out_layers, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = self._prepare_text_ids(prompt_embeds) + text_ids = text_ids.to(device) + return prompt_embeds, text_ids + + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if image.ndim != 4: + raise ValueError(f"Expected image dims 4, got {image.ndim}.") + + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = self._patchify_latents(image_latents) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps) + image_latents = (image_latents - latents_bn_mean) / latents_bn_std + + return image_latents + + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_latents_channels, + height, + width, + dtype, + device, + generator: torch.Generator, + latents: torch.Tensor | None = None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_latents_channels * 4, height // 2, width // 2) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + latent_ids = self._prepare_latent_ids(latents) + latent_ids = latent_ids.to(device) + + latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C] + return latents, latent_ids + + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_image_latents + def prepare_image_latents( + self, + images: list[torch.Tensor], + batch_size, + generator: torch.Generator, + device, + dtype, + ): + image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + imagge_latent = self._encode_vae_image(image=image, generator=generator) + image_latents.append(imagge_latent) # (1, 128, 32, 32) + + image_latent_ids = self._prepare_image_ids(image_latents) + + # Pack each latent and concatenate + packed_latents = [] + for latent in image_latents: + # latent: (1, 128, 32, 32) + packed = self._pack_latents(latent) # (1, 1024, 128) + packed = packed.squeeze(0) # (1024, 128) - remove batch dim + packed_latents.append(packed) + + # Concatenate all reference tokens along sequence dimension + image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128) + image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128) + + image_latents = image_latents.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.to(device) + + return image_latents, image_latent_ids + + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + guidance_scale=None, + ): + if ( + height is not None + and height % (self.vae_scale_factor * 2) != 0 + or width is not None + and width % (self.vae_scale_factor * 2) != 0 + ): + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if guidance_scale > 1.0 and self.config.is_distilled: + logger.warning(f"Guidance scale {guidance_scale} is ignored for step-wise distilled models.") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and not self.config.is_distilled + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: list[PIL.Image.Image] | PIL.Image.Image | None = None, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: str | list[str] | None = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int, dict], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + text_encoder_out_layers: tuple[int] = (9, 18, 27), + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. For step-wise distilled models, + `guidance_scale` is ignored. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Note that "" is used as the negative prompt in this pipeline. + If not provided, will be generated from "". + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + text_encoder_out_layers (`tuple[int]`): + Layer indices to use in the `text_encoder` to derive the final prompt embeddings. + + Examples: + + Returns: + [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + guidance_scale=guidance_scale, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. prepare text embeddings + prompt_embeds, text_ids = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + text_encoder_out_layers=text_encoder_out_layers, + ) + + if self.do_classifier_free_guidance: + negative_prompt = "" + if prompt is not None and isinstance(prompt, list): + negative_prompt = [negative_prompt] * len(prompt) + negative_prompt_embeds, negative_text_ids = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + text_encoder_out_layers=text_encoder_out_layers, + ) + + # 4. process images + if image is not None and not isinstance(image, list): + image = [image] + + condition_images = None + if image is not None: + for img in image: + self.image_processor.check_image_input(img) + + condition_images = [] + for img in image: + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + img = self.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + + multiple_of = self.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") + condition_images.append(img) + height = height or image_height + width = width or image_width + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 5. prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_ids = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_latents_channels=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + image_latents = None + image_latent_ids = None + if condition_images is not None: + image_latents, image_latent_ids = self.prepare_image_latents( + images=condition_images, + batch_size=batch_size * num_images_per_prompt, + generator=generator, + device=device, + dtype=self.vae.dtype, + ) + + # 6. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + image_seq_len = latents.shape[1] + mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 7. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + latent_model_input = latents.to(self.transformer.dtype) + latent_image_ids = latent_ids + + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) + latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, # (B, image_seq_len, C) + timestep=timestep / 1000, + guidance=None, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, # B, text_seq_len, 4 + img_ids=latent_image_ids, # B, image_seq_len, 4 + joint_attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1) :] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=None, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self._attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1) :] + noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + latents = self._unpack_latents_with_ids(latents, latent_ids) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + latents = latents * latents_bn_std + latents_bn_mean + latents = self._unpatchify_latents(latents) + if output_type == "latent": + image = latents + else: + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return Flux2PipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..34ae9b574453349fdf42c2e98680f1cee315fa0a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/pipeline_output.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class Flux2PipelineOutput(BaseOutput): + """ + Output class for Flux2 image generation pipelines. + + Args: + images (`list[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size, + height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion + pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be + passed to the decoder. + """ + + images: list[PIL.Image.Image, np.ndarray] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/system_messages.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/system_messages.py new file mode 100644 index 0000000000000000000000000000000000000000..ecdb1371f0d49afa94788e0f02f89837469a706d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/flux2/system_messages.py @@ -0,0 +1,33 @@ +# docstyle-ignore +""" +These system prompts come from: +https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/system_messages.py#L54 +""" + +# docstyle-ignore +SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object +attribution and actions without speculation.""" + +# docstyle-ignore +SYSTEM_MESSAGE_UPSAMPLING_T2I = """You are an expert prompt engineer for FLUX.2 by Black Forest Labs. Rewrite user prompts to be more descriptive while strictly preserving their core subject and intent. + +Guidelines: +1. Structure: Keep structured inputs structured (enhance within fields). Convert natural language to detailed paragraphs. +2. Details: Add concrete visual specifics - form, scale, textures, materials, lighting (quality, direction, color), shadows, spatial relationships, and environmental context. +3. Text in Images: Put ALL text in quotation marks, matching the prompt's language. Always provide explicit quoted text for objects that would contain text in reality (signs, labels, screens, etc.) - without it, the model generates gibberish. + +Output only the revised prompt and nothing else.""" + +# docstyle-ignore +SYSTEM_MESSAGE_UPSAMPLING_I2I = """You are FLUX.2 by Black Forest Labs, an image-editing expert. You convert editing requests into one concise instruction (50-80 words, ~30 for brief requests). + +Rules: +- Single instruction only, no commentary +- Use clear, analytical language (avoid "whimsical," "cascading," etc.) +- Specify what changes AND what stays the same (face, lighting, composition) +- Reference actual image elements +- Turn negatives into positives ("don't change X" → "keep X") +- Make abstractions concrete ("futuristic" → "glowing cyan neon, metallic panels") +- Keep content PG-13 + +Output only the final instruction in plain text and nothing else.""" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/glm_image/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/glm_image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..140b9cc760cc1b163dbb7027e29d4a03f3c787d8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/glm_image/__init__.py @@ -0,0 +1,59 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["GlmImagePipelineOutput"]} + +# Import transformers components so they can be resolved during pipeline loading + +if is_transformers_available() and is_transformers_version(">=", "4.57.4"): + try: + from transformers import GlmImageForConditionalGeneration, GlmImageProcessor + + _additional_imports["GlmImageForConditionalGeneration"] = GlmImageForConditionalGeneration + _additional_imports["GlmImageProcessor"] = GlmImageProcessor + except ImportError: + pass + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_glm_image"] = ["GlmImagePipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_glm_image import GlmImagePipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/glm_image/pipeline_glm_image.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/glm_image/pipeline_glm_image.py new file mode 100644 index 0000000000000000000000000000000000000000..859b371b2514ad6f3ad026ea9239d2e8aadf20b8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -0,0 +1,1049 @@ +# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import re +from typing import Any, Callable + +import numpy as np +import PIL +import torch +from transformers import ByT5Tokenizer, PreTrainedModel, ProcessorMixin, T5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, GlmImageTransformer2DModel +from ...models.transformers.transformer_glm_image import GlmImageKVCache +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, is_transformers_version, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import GlmImagePipelineOutput + + +# Because it's not released in stable as of 13/01/2026. So this is just a proxy. +GlmImageProcessor = ProcessorMixin +GlmImageForConditionalGeneration = PreTrainedModel +if is_transformers_version(">=", "5.0.0.dev0"): + from transformers import GlmImageForConditionalGeneration, GlmImageProcessor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import GlmImagePipeline + + >>> pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + >>> image.save("output.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + base_shift: float = 0.25, + max_shift: float = 0.75, +) -> float: + m = (image_seq_len / base_seq_len) ** 0.5 + mu = m * max_shift + base_shift + return mu + + +# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + + if timesteps is not None and sigmas is not None: + if not accepts_timesteps and not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep or sigma schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif timesteps is not None and sigmas is None: + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif timesteps is None and sigmas is not None: + if not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class GlmImagePipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using GLM-Image. + + This pipeline integrates both the AR (autoregressive) model for token generation and the DiT (diffusion + transformer) model for image decoding. + + Args: + tokenizer (`PreTrainedTokenizer`): + Tokenizer for the text encoder. + processor (`AutoProcessor`): + Processor for the AR model to handle chat templates and tokenization. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder for glyph embeddings. + vision_language_encoder ([`GlmImageForConditionalGeneration`]): + The AR model that generates image tokens from text prompts. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + transformer ([`GlmImageTransformer2DModel`]): + A text conditioned transformer to denoise the encoded image latents (DiT). + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "vision_language_encoder->text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + tokenizer: ByT5Tokenizer, + processor: GlmImageProcessor, + text_encoder: T5EncoderModel, + vision_language_encoder: GlmImageForConditionalGeneration, + vae: AutoencoderKL, + transformer: GlmImageTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + processor=processor, + text_encoder=text_encoder, + vision_language_encoder=vision_language_encoder, + vae=vae, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") + and self.transformer is not None + and hasattr(self.transformer.config, "sample_size") + else 128 + ) + + @staticmethod + def _compute_generation_params( + image_grid_thw, + is_text_to_image: bool, + ): + grid_sizes = [] + grid_hw = [] + + for i in range(image_grid_thw.shape[0]): + t, h, w = image_grid_thw[i].tolist() + grid_sizes.append(int(h * w)) + grid_hw.append((int(h), int(w))) + + if not is_text_to_image: + max_new_tokens = grid_sizes[-1] + 1 + large_image_start_offset = 0 + target_grid_h, target_grid_w = grid_hw[-1] + else: + total_tokens = sum(grid_sizes) + max_new_tokens = total_tokens + 1 + large_image_start_offset = sum(grid_sizes[1:]) + target_grid_h, target_grid_w = grid_hw[0] + return max_new_tokens, large_image_start_offset, target_grid_h, target_grid_w + + @staticmethod + def _extract_large_image_tokens( + outputs: torch.Tensor, input_length: int, large_image_start_offset: int, large_image_tokens: int + ) -> torch.Tensor: + generated_tokens = outputs[0][input_length:] + large_image_start = large_image_start_offset + large_image_end = large_image_start + large_image_tokens + return generated_tokens[large_image_start:large_image_end] + + @staticmethod + def _upsample_token_ids(token_ids: torch.Tensor, token_h: int, token_w: int) -> torch.Tensor: + token_ids = token_ids.view(1, 1, token_h, token_w) + token_ids = torch.nn.functional.interpolate(token_ids.float(), scale_factor=2, mode="nearest").to( + dtype=torch.long + ) + token_ids = token_ids.view(1, -1) + return token_ids + + @staticmethod + def _validate_and_normalize_images( + image: list[PIL.Image.Image] | list[list[PIL.Image.Image]], + batch_size: int, + ) -> list[list[PIL.Image.Image]]: + """ + Validate and normalize image inputs to List[List[PIL.Image]]. + + Rules: + - batch_size > 1: Only accepts List[List[PIL.Image]], each sublist must have equal length + - batch_size == 1: Accepts List[PIL.Image] for legacy compatibility (converted to [[img1, img2, ...]]) + - Other formats raise ValueError + + Args: + image: Input images in various formats + batch_size: Number of prompts in the batch + + Returns: + Normalized images as List[List[PIL.Image]], or None if no images provided + """ + if image is None or len(image) == 0: + return None + + first_element = image[0] + + if batch_size == 1: + # Legacy format: List[PIL.Image] -> [[img1, img2, ...]] + if not isinstance(first_element, (list, tuple)): + return [list(image)] + # Already in List[List[PIL.Image]] format + if len(image) != 1: + raise ValueError( + f"For batch_size=1 with List[List[PIL.Image]] format, expected 1 image list, got {len(image)}." + ) + return [list(image[0])] + + # batch_size > 1: must be List[List[PIL.Image]] + if not isinstance(first_element, (list, tuple)): + raise ValueError( + f"For batch_size > 1, images must be List[List[PIL.Image]] format. " + f"Got List[{type(first_element).__name__}] instead. " + f"Each prompt requires its own list of condition images." + ) + + if len(image) != batch_size: + raise ValueError(f"Number of image lists ({len(image)}) must match batch size ({batch_size}).") + + # Validate homogeneous: all sublists must have same length + num_input_images_per_prompt = len(image[0]) + for idx, imgs in enumerate(image): + if len(imgs) != num_input_images_per_prompt: + raise ValueError( + f"All prompts must have the same number of condition images. " + f"Prompt 0 has {num_input_images_per_prompt} images, but prompt {idx} has {len(imgs)} images." + ) + + return [list(imgs) for imgs in image] + + def generate_prior_tokens( + self, + prompt: str | list[str], + height: int, + width: int, + image: list[list[PIL.Image.Image]] | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + ): + """ + Generate prior tokens for the DiT model using the AR model. + + Args: + prompt: Single prompt or list of prompts + height: Target image height + width: Target image width + image: Normalized image input as List[List[PIL.Image]]. Should be pre-validated + using _validate_and_normalize_images() before calling this method. + device: Target device + generator: Random generator for reproducibility + + Returns: + Tuple of: + - prior_token_ids: Tensor of shape (batch_size, num_tokens) with upsampled prior tokens + - prior_token_image_ids_per_sample: List of tensors, one per sample. Each tensor contains + the upsampled prior token ids for all condition images in that sample. None for t2i. + - source_image_grid_thw_per_sample: List of tensors, one per sample. Each tensor has shape + (num_condition_images, 3) with upsampled grid info. None for t2i. + """ + device = device or self._execution_device + + # Normalize prompt to list format + prompt_list = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt_list) + + # Image is already normalized by _validate_and_normalize_images(): None or List[List[PIL.Image]] + is_text_to_image = image is None + # Build messages for each sample in the batch + all_messages = [] + for idx, p in enumerate(prompt_list): + content = [] + if not is_text_to_image: + for img in image[idx]: + content.append({"type": "image", "image": img}) + content.append({"type": "text", "text": p}) + all_messages.append([{"role": "user", "content": content}]) + # Process with the processor (supports batch with left padding) + inputs = self.processor.apply_chat_template( + all_messages, + tokenize=True, + padding=True if batch_size > 1 else False, + target_h=height, + target_w=width, + return_dict=True, + return_tensors="pt", + ).to(device) + + image_grid_thw = inputs.get("image_grid_thw") + images_per_sample = inputs.get("images_per_sample") + + # Determine number of condition images and grids per sample + num_condition_images = 0 if is_text_to_image else len(image[0]) + if images_per_sample is not None: + num_grids_per_sample = images_per_sample[0].item() + else: + # Fallback for batch_size=1: total grids is for single sample + num_grids_per_sample = image_grid_thw.shape[0] + + # Compute generation params (same for all samples in homogeneous batch) + first_sample_grids = image_grid_thw[:num_grids_per_sample] + max_new_tokens, large_image_offset, token_h, token_w = self._compute_generation_params( + image_grid_thw=first_sample_grids, is_text_to_image=is_text_to_image + ) + + # Generate source image tokens (prior_token_image_ids) for i2i mode + prior_token_image_ids = None + source_image_grid_thw = None + if not is_text_to_image: + # Extract source grids by selecting condition image indices (skip target grids) + # Grid order from processor: [s0_cond1, s0_cond2, ..., s0_target, s1_cond1, s1_cond2, ..., s1_target, ...] + # We need indices: [0, 1, ..., num_condition_images-1, num_grids_per_sample, num_grids_per_sample+1, ...] + source_indices = [] + for sample_idx in range(batch_size): + base = sample_idx * num_grids_per_sample + source_indices.extend(range(base, base + num_condition_images)) + source_grids = image_grid_thw[source_indices] + + if len(source_grids) > 0: + prior_token_image_embed = self.vision_language_encoder.get_image_features( + inputs["pixel_values"], source_grids + ).pooler_output + prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0) + prior_token_image_ids_d32 = self.vision_language_encoder.get_image_tokens( + prior_token_image_embed, source_grids + ) + # Upsample each source image's prior tokens to match VAE/DiT resolution + split_sizes = source_grids.prod(dim=-1).tolist() + prior_ids_per_source = torch.split(prior_token_image_ids_d32, split_sizes) + upsampled_prior_ids = [] + for i, prior_ids in enumerate(prior_ids_per_source): + t, h, w = source_grids[i].tolist() + upsampled = self._upsample_token_ids(prior_ids, int(h), int(w)) + upsampled_prior_ids.append(upsampled.squeeze(0)) + prior_token_image_ids = torch.cat(upsampled_prior_ids, dim=0) + # Upsample grid dimensions for later splitting + upsampled_grids = source_grids.clone() + upsampled_grids[:, 1] = upsampled_grids[:, 1] * 2 + upsampled_grids[:, 2] = upsampled_grids[:, 2] * 2 + source_image_grid_thw = upsampled_grids + + # Generate with AR model + # Set torch random seed from generator for reproducibility + # (transformers generate() doesn't accept generator parameter) + if generator is not None: + seed = generator.initial_seed() + torch.manual_seed(seed) + if device is not None and device.type == "cuda": + torch.cuda.manual_seed(seed) + outputs = self.vision_language_encoder.generate( + **inputs, + max_new_tokens=max_new_tokens, + do_sample=True, + ) + + # Extract and upsample prior tokens for each sample + # For left-padded inputs, generated tokens start after the padded input sequence + all_prior_token_ids = [] + max_input_length = inputs["input_ids"].shape[-1] + for idx in range(batch_size): + # For left-padded sequences, generated tokens start at max_input_length + # (padding is on the left, so all sequences end at the same position) + prior_token_ids_d32 = self._extract_large_image_tokens( + outputs[idx : idx + 1], max_input_length, large_image_offset, token_h * token_w + ) + prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w) + all_prior_token_ids.append(prior_token_ids) + prior_token_ids = torch.cat(all_prior_token_ids, dim=0) + + # Split prior_token_image_ids and source_image_grid_thw into per-sample lists for easier consumption + prior_token_image_ids_per_sample = None + source_image_grid_thw_per_sample = None + if prior_token_image_ids is not None and source_image_grid_thw is not None: + # Split grids: each sample has num_condition_images grids + source_image_grid_thw_per_sample = list(torch.split(source_image_grid_thw, num_condition_images)) + # Split prior_token_image_ids: tokens per sample may vary due to different image sizes + tokens_per_image = source_image_grid_thw.prod(dim=-1).tolist() + tokens_per_sample = [] + for i in range(batch_size): + start_idx = i * num_condition_images + end_idx = start_idx + num_condition_images + tokens_per_sample.append(sum(tokens_per_image[start_idx:end_idx])) + prior_token_image_ids_per_sample = list(torch.split(prior_token_image_ids, tokens_per_sample)) + + return prior_token_ids, prior_token_image_ids_per_sample, source_image_grid_thw_per_sample + + def get_glyph_texts(self, prompt): + """Extract glyph texts from prompt(s). Returns a list of lists for batch processing.""" + if isinstance(prompt, str): + prompt = [prompt] + all_ocr_texts = [] + for p in prompt: + ocr_texts = ( + re.findall(r"'([^']*)'", p) + + re.findall(r"\u201c([^\u201c\u201d]*)\u201d", p) + + re.findall(r'"([^"]*)"', p) + + re.findall(r"「([^「」]*)」", p) + ) + all_ocr_texts.append(ocr_texts) + return all_ocr_texts + + def _get_glyph_embeds( + self, + prompt: str | list[str] = None, + max_sequence_length: int = 2048, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + """Get glyph embeddings for each prompt in the batch.""" + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + # get_glyph_texts now returns a list of lists (one per prompt) + all_glyph_texts = self.get_glyph_texts(prompt) + + all_glyph_embeds = [] + for glyph_texts in all_glyph_texts: + if len(glyph_texts) == 0: + glyph_texts = [""] + input_ids = self.tokenizer( + glyph_texts, + max_length=max_sequence_length, + truncation=True, + ).input_ids + input_ids = [ + [self.tokenizer.pad_token_id] * ((len(input_ids) + 1) % 2) + input_ids_ for input_ids_ in input_ids + ] + max_length = max(len(input_ids_) for input_ids_ in input_ids) + attention_mask = torch.tensor( + [[1] * len(input_ids_) + [0] * (max_length - len(input_ids_)) for input_ids_ in input_ids], + device=device, + ) + input_ids = torch.tensor( + [ + input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_)) + for input_ids_ in input_ids + ], + device=device, + ) + outputs = self.text_encoder(input_ids, attention_mask=attention_mask) + glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0) + all_glyph_embeds.append(glyph_embeds) + + # Pad to same sequence length and stack (use left padding to match transformers) + max_seq_len = max(emb.size(1) for emb in all_glyph_embeds) + padded_embeds = [] + for emb in all_glyph_embeds: + if emb.size(1) < max_seq_len: + pad = torch.zeros(emb.size(0), max_seq_len - emb.size(1), emb.size(2), device=device, dtype=emb.dtype) + emb = torch.cat([pad, emb], dim=1) # left padding + padded_embeds.append(emb) + + glyph_embeds = torch.cat(padded_embeds, dim=0) + return glyph_embeds.to(device=device, dtype=dtype) + + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 2048, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + max_sequence_length (`int`, defaults to `2048`): + Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_glyph_embeds(prompt, max_sequence_length, device, dtype) + + # Repeat embeddings for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + # For GLM-Image, negative_prompt must be "" instead of None + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_embeds = self._get_glyph_embeds(negative_prompt, max_sequence_length, device, dtype) + + if num_images_per_prompt > 1: + negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if latents is not None: + return latents.to(device) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + prior_token_ids=None, + prior_token_image_ids=None, + source_image_grid_thw=None, + image=None, + ): + if ( + height is not None + and height % (self.vae_scale_factor * self.transformer.config.patch_size * 2) != 0 + or width is not None + and width % (self.transformer.config.patch_size * 2) != 0 + ): + # GLM-Image uses 32× downsampling, so the image dimensions must be multiples of 32. + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 4} but are {height} and {width}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if prompt is None and prior_token_ids is None: + raise ValueError( + "Provide either `prompt` or `prior_token_ids`. Cannot leave both `prompt` and `prior_token_ids` undefined." + ) + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + # Validate prior token inputs: for i2i mode, all three must be provided together + # For t2i mode, only prior_token_ids is needed (prior_token_image_ids and source_image_grid_thw should be None) + prior_image_inputs = [prior_token_image_ids, source_image_grid_thw] + num_prior_image_inputs = sum(x is not None for x in prior_image_inputs) + if num_prior_image_inputs > 0 and num_prior_image_inputs < len(prior_image_inputs): + raise ValueError( + "`prior_token_image_ids` and `source_image_grid_thw` must be provided together for i2i mode. " + f"Got prior_token_image_ids={prior_token_image_ids is not None}, " + f"source_image_grid_thw={source_image_grid_thw is not None}." + ) + if num_prior_image_inputs > 0 and prior_token_ids is None: + raise ValueError( + "`prior_token_ids` must be provided when `prior_token_image_ids` and `source_image_grid_thw` are provided." + ) + if num_prior_image_inputs > 0 and image is None: + raise ValueError( + "`image` must be provided when `prior_token_image_ids` and `source_image_grid_thw` are provided " + "for i2i mode, as the images are needed for VAE encoding to build the KV cache." + ) + + if prior_token_ids is not None and prompt_embeds is None and prompt is None: + raise ValueError("`prompt_embeds` or `prompt` must also be provided with `prior_token_ids`.") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + image: torch.Tensor + | PIL.Image.Image + | np.ndarray + | list[torch.Tensor] + | list[PIL.Image.Image] + | list[np.ndarray] + | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + guidance_scale: float = 1.5, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prior_token_ids: torch.Tensor | None = None, + prior_token_image_ids: list[torch.Tensor] | None = None, + source_image_grid_thw: list[torch.Tensor] | None = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int, dict], None] + | PipelineCallback + | MultiPipelineCallbacks + | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 2048, + ) -> GlmImagePipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. Must contain shape info in the format 'H + W' where H and W are token dimensions (d32). Example: "A beautiful sunset36 24" + generates a 1152x768 image. + image: Optional condition images for image-to-image generation. + height (`int`, *optional*): + The height in pixels. If not provided, derived from prompt shape info. + width (`int`, *optional*): + The width in pixels. If not provided, derived from prompt shape info. + num_inference_steps (`int`, *optional*, defaults to `50`): + The number of denoising steps for DiT. + guidance_scale (`float`, *optional*, defaults to `1.5`): + Guidance scale for classifier-free guidance. + num_images_per_prompt (`int`, *optional*, defaults to `1`): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + Random generator for reproducibility. + output_type (`str`, *optional*, defaults to `"pil"`): + Output format: "pil", "np", or "latent". + + Examples: + + Returns: + [`GlmImagePipelineOutput`] or `tuple`: Generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs + self.check_inputs( + prompt, + height, + width, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + prior_token_ids, + prior_token_image_ids, + source_image_grid_thw, + image, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 2. Validate and normalize image format + normalized_image = self._validate_and_normalize_images(image, batch_size) + + # 3. Generate prior tokens (batch mode) + # Get a single generator for AR model (use first if list provided) + ar_generator = generator[0] if isinstance(generator, list) else generator + if prior_token_ids is None: + prior_token_ids, prior_token_image_ids_per_sample, source_image_grid_thw_per_sample = ( + self.generate_prior_tokens( + prompt=prompt, + image=normalized_image, + height=height, + width=width, + device=device, + generator=ar_generator, + ) + ) + else: + # User provided prior_token_ids directly (from generate_prior_tokens) + prior_token_image_ids_per_sample = prior_token_image_ids + source_image_grid_thw_per_sample = source_image_grid_thw + + # 4. Preprocess images for VAE encoding + preprocessed_images = None + if normalized_image is not None: + preprocessed_images = [] + for prompt_images in normalized_image: + prompt_preprocessed = [] + for img in prompt_images: + image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2] + multiple_of = self.vae_scale_factor * self.transformer.config.patch_size + image_height = (image_height // multiple_of) * multiple_of + image_width = (image_width // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width) + prompt_preprocessed.append(img) + height = height or image_height + width = width or image_width + preprocessed_images.append(prompt_preprocessed) + + # 5. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + self.do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + dtype=self.dtype, + ) + + # 6. Prepare latents and (optional) image kv cache + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_channels_latents=latent_channels, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + kv_caches = GlmImageKVCache(num_layers=self.transformer.config.num_layers) + + if normalized_image is not None: + kv_caches.set_mode("write") + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.latent_channels, 1, 1) + latents_std = torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.latent_channels, 1, 1) + + latents_mean = latents_mean.to(device=device, dtype=prompt_embeds.dtype) + latents_std = latents_std.to(device=device, dtype=prompt_embeds.dtype) + + # Process each sample's condition images + for prompt_idx in range(batch_size): + prompt_images = preprocessed_images[prompt_idx] + prompt_prior_ids = prior_token_image_ids_per_sample[prompt_idx] + prompt_grid_thw = source_image_grid_thw_per_sample[prompt_idx] + + # Split this sample's prior_token_image_ids by each image's token count + split_sizes = prompt_grid_thw.prod(dim=-1).tolist() + prior_ids_per_image = torch.split(prompt_prior_ids, split_sizes) + # Process each condition image for this sample + for condition_image, condition_image_prior_token_id in zip(prompt_images, prior_ids_per_image): + condition_image = condition_image.to(device=device, dtype=prompt_embeds.dtype) + condition_latent = retrieve_latents( + self.vae.encode(condition_image), generator=generator, sample_mode="argmax" + ) + condition_latent = (condition_latent - latents_mean) / latents_std + + _ = self.transformer( + hidden_states=condition_latent, + encoder_hidden_states=torch.zeros_like(prompt_embeds)[:1, :0, ...], + prior_token_id=condition_image_prior_token_id, + prior_token_drop=torch.full_like(condition_image_prior_token_id, False, dtype=torch.bool), + timestep=torch.zeros((1,), device=device), + target_size=torch.tensor([condition_image.shape[-2:]], device=device), + crop_coords=torch.zeros((1, 2), device=device), + attention_kwargs=attention_kwargs, + kv_caches=kv_caches, + ) + # Move to next sample's cache slot + kv_caches.next_sample() + + # 7. Prepare additional timestep conditions + target_size = (height, width) + target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device) + crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device) + + target_size = target_size.repeat(batch_size * num_images_per_prompt, 1) + crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1) + + # Prepare timesteps + image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // ( + self.transformer.config.patch_size**2 + ) + timesteps = ( + np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps + 1)[:-1] + if timesteps is None + else np.array(timesteps) + ) + timesteps = timesteps.astype(np.int64).astype(np.float32) + sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("base_shift", 0.25), + self.scheduler.config.get("max_shift", 0.75), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu + ) + self._num_timesteps = len(timesteps) + + # 8. Denoising loop + transformer_dtype = self.transformer.dtype + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # Repeat prior_token_ids for num_images_per_prompt + if num_images_per_prompt > 1: + prior_token_ids = prior_token_ids.repeat_interleave(num_images_per_prompt, dim=0) + prior_token_drop_cond = torch.full_like(prior_token_ids, False, dtype=torch.bool) + prior_token_drop_uncond = torch.full_like(prior_token_ids, True, dtype=torch.bool) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(transformer_dtype) + + timestep = t.expand(latents.shape[0]) - 1 + + if prior_token_image_ids_per_sample is not None: + kv_caches.set_mode("read") + + noise_pred_cond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + prior_token_id=prior_token_ids, + prior_token_drop=prior_token_drop_cond, + timestep=timestep, + target_size=target_size, + crop_coords=crops_coords_top_left, + attention_kwargs=attention_kwargs, + return_dict=False, + kv_caches=kv_caches, + )[0].float() + + # perform guidance + if self.do_classifier_free_guidance: + if prior_token_image_ids_per_sample is not None: + kv_caches.set_mode("skip") + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=negative_prompt_embeds, + prior_token_id=prior_token_ids, + prior_token_drop=prior_token_drop_uncond, + timestep=timestep, + target_size=target_size, + crop_coords=crops_coords_top_left, + attention_kwargs=attention_kwargs, + return_dict=False, + kv_caches=kv_caches, + )[0].float() + + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = noise_pred_cond + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + kv_caches.clear() + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.latent_channels, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.vae.config.latent_channels, 1, 1) + .to(latents.device, latents.dtype) + ) + latents = latents * latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False, generator=generator)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + else: + image = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return GlmImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/glm_image/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/glm_image/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..d4fd061335d4fc7cc4f0a03c25acccb78393b166 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/glm_image/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class GlmImagePipelineOutput(BaseOutput): + """ + Output class for CogView3 pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image] | np.ndarray diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/helios/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/helios/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae08f5997279ff7417300e94f44787b34e4faa0b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/helios/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_helios"] = ["HeliosPipeline"] + _import_structure["pipeline_helios_pyramid"] = ["HeliosPyramidPipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_helios import HeliosPipeline + from .pipeline_helios_pyramid import HeliosPyramidPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/helios/pipeline_helios.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/helios/pipeline_helios.py new file mode 100644 index 0000000000000000000000000000000000000000..87a8600badaba7ad8013e9d5818f5e125d0cd951 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/helios/pipeline_helios.py @@ -0,0 +1,916 @@ +# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +from typing import Any, Callable + +import numpy as np +import regex as re +import torch +from transformers import AutoTokenizer, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import HeliosLoraLoaderMixin +from ...models import AutoencoderKLWan, HeliosTransformer3DModel +from ...schedulers import HeliosScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HeliosPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers.utils import export_to_video + >>> from diffusers import AutoencoderKLWan, HeliosPipeline + + >>> # Available models: BestWishYsh/Helios-Base, BestWishYsh/Helios-Mid, BestWishYsh/Helios-Distilled + >>> model_id = "BestWishYsh/Helios-Base" + >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + >>> pipe = HeliosPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + + >>> output = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=384, + ... width=640, + ... num_frames=132, + ... guidance_scale=5.0, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=24) + ``` +""" + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +class HeliosPipeline(DiffusionPipeline, HeliosLoraLoaderMixin): + r""" + Pipeline for text-to-video / image-to-video / video-to-video generation using Helios. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + transformer ([`HeliosTransformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`HeliosScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _optional_components = ["transformer"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKLWan, + scheduler: HeliosScheduler, + transformer: HeliosTransformer3DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, text_inputs.attention_mask.bool() + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, _ = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, _ = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + image=None, + video=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if image is not None and video is not None: + raise ValueError("image and video cannot be provided simultaneously") + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 384, + width: int = 640, + num_frames: int = 33, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def prepare_image_latents( + self, + image: torch.Tensor, + latents_mean: torch.Tensor, + latents_std: torch.Tensor, + num_latent_frames_per_chunk: int, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + fake_latents: torch.Tensor | None = None, + ) -> torch.Tensor: + device = device or self._execution_device + if latents is None: + image = image.unsqueeze(2).to(device=device, dtype=self.vae.dtype) + latents = self.vae.encode(image).latent_dist.sample(generator=generator) + latents = (latents - latents_mean) * latents_std + if fake_latents is None: + min_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1 + fake_video = image.repeat(1, 1, min_frames, 1, 1).to(device=device, dtype=self.vae.dtype) + fake_latents_full = self.vae.encode(fake_video).latent_dist.sample(generator=generator) + fake_latents_full = (fake_latents_full - latents_mean) * latents_std + fake_latents = fake_latents_full[:, :, -1:, :, :] + return latents.to(device=device, dtype=dtype), fake_latents.to(device=device, dtype=dtype) + + def prepare_video_latents( + self, + video: torch.Tensor, + latents_mean: torch.Tensor, + latents_std: torch.Tensor, + num_latent_frames_per_chunk: int, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + device = device or self._execution_device + video = video.to(device=device, dtype=self.vae.dtype) + if latents is None: + num_frames = video.shape[2] + min_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1 + num_chunks = num_frames // min_frames + if num_chunks == 0: + raise ValueError( + f"Video must have at least {min_frames} frames " + f"(got {num_frames} frames). " + f"Required: (num_latent_frames_per_chunk - 1) * {self.vae_scale_factor_temporal} + 1 = ({num_latent_frames_per_chunk} - 1) * {self.vae_scale_factor_temporal} + 1 = {min_frames}" + ) + total_valid_frames = num_chunks * min_frames + start_frame = num_frames - total_valid_frames + + first_frame = video[:, :, 0:1, :, :] + first_frame_latent = self.vae.encode(first_frame).latent_dist.sample(generator=generator) + first_frame_latent = (first_frame_latent - latents_mean) * latents_std + + latents_chunks = [] + for i in range(num_chunks): + chunk_start = start_frame + i * min_frames + chunk_end = chunk_start + min_frames + video_chunk = video[:, :, chunk_start:chunk_end, :, :] + chunk_latents = self.vae.encode(video_chunk).latent_dist.sample(generator=generator) + chunk_latents = (chunk_latents - latents_mean) * latents_std + latents_chunks.append(chunk_latents) + latents = torch.cat(latents_chunks, dim=2) + return first_frame_latent.to(device=device, dtype=dtype), latents.to(device=device, dtype=dtype) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int = 384, + width: int = 640, + num_frames: int = 132, + num_inference_steps: int = 50, + sigmas: list[float] = None, + guidance_scale: float = 5.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + # ------------ I2V ------------ + image: PipelineImageInput | None = None, + image_latents: torch.Tensor | None = None, + fake_image_latents: torch.Tensor | None = None, + add_noise_to_image_latents: bool = True, + image_noise_sigma_min: float = 0.111, + image_noise_sigma_max: float = 0.135, + # ------------ V2V ------------ + video: PipelineImageInput | None = None, + video_latents: torch.Tensor | None = None, + add_noise_to_video_latents: bool = True, + video_noise_sigma_min: float = 0.111, + video_noise_sigma_max: float = 0.135, + # ------------ Stage 1 ------------ + history_sizes: list = [16, 2, 1], + num_latent_frames_per_chunk: int = 9, + keep_first_frame: bool = True, + is_skip_first_chunk: bool = False, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). + height (`int`, defaults to `384`): + The height in pixels of the generated image. + width (`int`, defaults to `640`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `132`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HeliosPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + + Examples: + + Returns: + [`~HeliosPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HeliosPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + history_sizes = sorted(history_sizes, reverse=True) # From big to small + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + image, + video, + ) + + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + vae_dtype = self.vae.dtype + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(device, self.vae.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, self.vae.dtype + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Prepare image or video + if image is not None: + image = self.video_processor.preprocess(image, height=height, width=width) + image_latents, fake_image_latents = self.prepare_image_latents( + image, + latents_mean=latents_mean, + latents_std=latents_std, + num_latent_frames_per_chunk=num_latent_frames_per_chunk, + dtype=torch.float32, + device=device, + generator=generator, + latents=image_latents, + fake_latents=fake_image_latents, + ) + + if image_latents is not None and add_noise_to_image_latents: + image_noise_sigma = ( + torch.rand(1, device=device, generator=generator) * (image_noise_sigma_max - image_noise_sigma_min) + + image_noise_sigma_min + ) + image_latents = ( + image_noise_sigma * randn_tensor(image_latents.shape, generator=generator, device=device) + + (1 - image_noise_sigma) * image_latents + ) + fake_image_noise_sigma = ( + torch.rand(1, device=device, generator=generator) * (video_noise_sigma_max - video_noise_sigma_min) + + video_noise_sigma_min + ) + fake_image_latents = ( + fake_image_noise_sigma * randn_tensor(fake_image_latents.shape, generator=generator, device=device) + + (1 - fake_image_noise_sigma) * fake_image_latents + ) + + if video is not None: + video = self.video_processor.preprocess_video(video, height=height, width=width) + image_latents, video_latents = self.prepare_video_latents( + video, + latents_mean=latents_mean, + latents_std=latents_std, + num_latent_frames_per_chunk=num_latent_frames_per_chunk, + dtype=torch.float32, + device=device, + generator=generator, + latents=video_latents, + ) + + if video_latents is not None and add_noise_to_video_latents: + image_noise_sigma = ( + torch.rand(1, device=device, generator=generator) * (image_noise_sigma_max - image_noise_sigma_min) + + image_noise_sigma_min + ) + image_latents = ( + image_noise_sigma * randn_tensor(image_latents.shape, generator=generator, device=device) + + (1 - image_noise_sigma) * image_latents + ) + + noisy_latents_chunks = [] + num_latent_chunks = video_latents.shape[2] // num_latent_frames_per_chunk + for i in range(num_latent_chunks): + chunk_start = i * num_latent_frames_per_chunk + chunk_end = chunk_start + num_latent_frames_per_chunk + latent_chunk = video_latents[:, :, chunk_start:chunk_end, :, :] + + chunk_frames = latent_chunk.shape[2] + frame_sigmas = ( + torch.rand(chunk_frames, device=device, generator=generator) + * (video_noise_sigma_max - video_noise_sigma_min) + + video_noise_sigma_min + ) + frame_sigmas = frame_sigmas.view(1, 1, chunk_frames, 1, 1) + + noisy_chunk = ( + frame_sigmas * randn_tensor(latent_chunk.shape, generator=generator, device=device) + + (1 - frame_sigmas) * latent_chunk + ) + noisy_latents_chunks.append(noisy_chunk) + video_latents = torch.cat(noisy_latents_chunks, dim=2) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + window_num_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1 + num_latent_chunk = max(1, (num_frames + window_num_frames - 1) // window_num_frames) + num_history_latent_frames = sum(history_sizes) + history_video = None + total_generated_latent_frames = 0 + + if not keep_first_frame: + history_sizes[-1] = history_sizes[-1] + 1 + history_latents = torch.zeros( + batch_size, + num_channels_latents, + num_history_latent_frames, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + device=device, + dtype=torch.float32, + ) + if fake_image_latents is not None: + history_latents = torch.cat([history_latents[:, :, :-1, :, :], fake_image_latents], dim=2) + total_generated_latent_frames += 1 + if video_latents is not None: + history_frames = history_latents.shape[2] + video_frames = video_latents.shape[2] + if video_frames < history_frames: + keep_frames = history_frames - video_frames + history_latents = torch.cat([history_latents[:, :, :keep_frames, :, :], video_latents], dim=2) + else: + history_latents = video_latents + total_generated_latent_frames += video_latents.shape[2] + + if keep_first_frame: + indices = torch.arange(0, sum([1, *history_sizes, num_latent_frames_per_chunk])) + ( + indices_prefix, + indices_latents_history_long, + indices_latents_history_mid, + indices_latents_history_1x, + indices_hidden_states, + ) = indices.split([1, *history_sizes, num_latent_frames_per_chunk], dim=0) + indices_latents_history_short = torch.cat([indices_prefix, indices_latents_history_1x], dim=0) + else: + indices = torch.arange(0, sum([*history_sizes, num_latent_frames_per_chunk])) + ( + indices_latents_history_long, + indices_latents_history_mid, + indices_latents_history_short, + indices_hidden_states, + ) = indices.split([*history_sizes, num_latent_frames_per_chunk], dim=0) + indices_hidden_states = indices_hidden_states.unsqueeze(0) + indices_latents_history_short = indices_latents_history_short.unsqueeze(0) + indices_latents_history_mid = indices_latents_history_mid.unsqueeze(0) + indices_latents_history_long = indices_latents_history_long.unsqueeze(0) + + # 6. Denoising loop + patch_size = self.transformer.config.patch_size + image_seq_len = ( + num_latent_frames_per_chunk + * (height // self.vae_scale_factor_spatial) + * (width // self.vae_scale_factor_spatial) + // (patch_size[0] * patch_size[1] * patch_size[2]) + ) + sigmas = np.linspace(0.999, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + + for k in range(num_latent_chunk): + is_first_chunk = k == 0 + is_second_chunk = k == 1 + if keep_first_frame: + latents_history_long, latents_history_mid, latents_history_1x = history_latents[ + :, :, -num_history_latent_frames: + ].split(history_sizes, dim=2) + if image_latents is None and is_first_chunk: + latents_prefix = torch.zeros( + ( + batch_size, + num_channels_latents, + 1, + latents_history_1x.shape[-2], + latents_history_1x.shape[-1], + ), + device=device, + dtype=latents_history_1x.dtype, + ) + else: + latents_prefix = image_latents + latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2) + else: + latents_history_long, latents_history_mid, latents_history_short = history_latents[ + :, :, -num_history_latent_frames: + ].split(history_sizes, dim=2) + + latents = self.prepare_latents( + batch_size, + num_channels_latents, + height, + width, + window_num_frames, + dtype=torch.float32, + device=device, + generator=generator, + latents=None, + ) + + self.scheduler.set_timesteps(num_inference_steps, device=device, sigmas=sigmas, mu=mu) + timesteps = self.scheduler.timesteps + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]) + + latent_model_input = latents.to(transformer_dtype) + latents_history_short = latents_history_short.to(transformer_dtype) + latents_history_mid = latents_history_mid.to(transformer_dtype) + latents_history_long = latents_history_long.to(transformer_dtype) + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + indices_hidden_states=indices_hidden_states, + indices_latents_history_short=indices_latents_history_short, + indices_latents_history_mid=indices_latents_history_mid, + indices_latents_history_long=indices_latents_history_long, + latents_history_short=latents_history_short, + latents_history_mid=latents_history_mid, + latents_history_long=latents_history_long, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + indices_hidden_states=indices_hidden_states, + indices_latents_history_short=indices_latents_history_short, + indices_latents_history_mid=indices_latents_history_mid, + indices_latents_history_long=indices_latents_history_long, + latents_history_short=latents_history_short, + latents_history_mid=latents_history_mid, + latents_history_long=latents_history_long, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + return_dict=False, + )[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if keep_first_frame and ( + (is_first_chunk and image_latents is None) or (is_skip_first_chunk and is_second_chunk) + ): + image_latents = latents[:, :, 0:1, :, :] + + total_generated_latent_frames += latents.shape[2] + history_latents = torch.cat([history_latents, latents], dim=2) + real_history_latents = history_latents[:, :, -total_generated_latent_frames:] + current_latents = ( + real_history_latents[:, :, -num_latent_frames_per_chunk:].to(vae_dtype) / latents_std + + latents_mean + ) + current_video = self.vae.decode(current_latents, return_dict=False)[0] + + if history_video is None: + history_video = current_video + else: + history_video = torch.cat([history_video, current_video], dim=2) + + self._current_timestep = None + + if output_type != "latent": + generated_frames = history_video.size(2) + generated_frames = ( + generated_frames - 1 + ) // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + history_video = history_video[:, :, :generated_frames] + video = self.video_processor.postprocess_video(history_video, output_type=output_type) + else: + video = real_history_latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HeliosPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/helios/pipeline_helios_pyramid.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/helios/pipeline_helios_pyramid.py new file mode 100644 index 0000000000000000000000000000000000000000..40c1d65825ff4698bb11c75e7a6c40321b606444 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/helios/pipeline_helios_pyramid.py @@ -0,0 +1,1065 @@ +# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import math +from typing import Any, Callable + +import regex as re +import torch +import torch.nn.functional as F +from transformers import AutoTokenizer, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import HeliosLoraLoaderMixin +from ...models import AutoencoderKLWan, HeliosTransformer3DModel +from ...schedulers import HeliosDMDScheduler, HeliosScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HeliosPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers.utils import export_to_video + >>> from diffusers import AutoencoderKLWan, HeliosPyramidPipeline + + >>> # Available models: BestWishYsh/Helios-Base, BestWishYsh/Helios-Mid, BestWishYsh/Helios-Distilled + >>> model_id = "BestWishYsh/Helios-Base" + >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + >>> pipe = HeliosPyramidPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + + >>> output = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=384, + ... width=640, + ... num_frames=132, + ... guidance_scale=5.0, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=24) + ``` +""" + + +def optimized_scale(positive_flat, negative_flat): + # Calculate dot production + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + # Squared norm of uncondition + squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8 + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + st_star = dot_product / squared_norm + return st_star + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +class HeliosPyramidPipeline(DiffusionPipeline, HeliosLoraLoaderMixin): + r""" + Pipeline for text-to-video / image-to-video / video-to-video generation using Helios. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + transformer ([`HeliosTransformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`HeliosScheduler`, `HeliosDMDScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _optional_components = ["transformer"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKLWan, + scheduler: HeliosScheduler | HeliosDMDScheduler, + transformer: HeliosTransformer3DModel, + is_cfg_zero_star: bool = False, + is_distilled: bool = False, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.register_to_config(is_cfg_zero_star=is_cfg_zero_star) + self.register_to_config(is_distilled=is_distilled) + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.helios.pipeline_helios.HeliosPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, text_inputs.attention_mask.bool() + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, _ = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, _ = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + image=None, + video=None, + guidance_scale=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if image is not None and video is not None: + raise ValueError("image and video cannot be provided simultaneously") + + if guidance_scale > 1.0 and self.config.is_distilled: + logger.warning(f"Guidance scale {guidance_scale} is ignored for step-wise distilled models.") + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 384, + width: int = 640, + num_frames: int = 33, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def prepare_image_latents( + self, + image: torch.Tensor, + latents_mean: torch.Tensor, + latents_std: torch.Tensor, + num_latent_frames_per_chunk: int, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + fake_latents: torch.Tensor | None = None, + ) -> torch.Tensor: + device = device or self._execution_device + if latents is None: + image = image.unsqueeze(2).to(device=device, dtype=self.vae.dtype) + latents = self.vae.encode(image).latent_dist.sample(generator=generator) + latents = (latents - latents_mean) * latents_std + if fake_latents is None: + min_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1 + fake_video = image.repeat(1, 1, min_frames, 1, 1).to(device=device, dtype=self.vae.dtype) + fake_latents_full = self.vae.encode(fake_video).latent_dist.sample(generator=generator) + fake_latents_full = (fake_latents_full - latents_mean) * latents_std + fake_latents = fake_latents_full[:, :, -1:, :, :] + return latents.to(device=device, dtype=dtype), fake_latents.to(device=device, dtype=dtype) + + def prepare_video_latents( + self, + video: torch.Tensor, + latents_mean: torch.Tensor, + latents_std: torch.Tensor, + num_latent_frames_per_chunk: int, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + device = device or self._execution_device + video = video.to(device=device, dtype=self.vae.dtype) + if latents is None: + num_frames = video.shape[2] + min_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1 + num_chunks = num_frames // min_frames + if num_chunks == 0: + raise ValueError( + f"Video must have at least {min_frames} frames " + f"(got {num_frames} frames). " + f"Required: (num_latent_frames_per_chunk - 1) * {self.vae_scale_factor_temporal} + 1 = ({num_latent_frames_per_chunk} - 1) * {self.vae_scale_factor_temporal} + 1 = {min_frames}" + ) + total_valid_frames = num_chunks * min_frames + start_frame = num_frames - total_valid_frames + + first_frame = video[:, :, 0:1, :, :] + first_frame_latent = self.vae.encode(first_frame).latent_dist.sample(generator=generator) + first_frame_latent = (first_frame_latent - latents_mean) * latents_std + + latents_chunks = [] + for i in range(num_chunks): + chunk_start = start_frame + i * min_frames + chunk_end = chunk_start + min_frames + video_chunk = video[:, :, chunk_start:chunk_end, :, :] + chunk_latents = self.vae.encode(video_chunk).latent_dist.sample(generator=generator) + chunk_latents = (chunk_latents - latents_mean) * latents_std + latents_chunks.append(chunk_latents) + latents = torch.cat(latents_chunks, dim=2) + return first_frame_latent.to(device=device, dtype=dtype), latents.to(device=device, dtype=dtype) + + def sample_block_noise( + self, + batch_size, + channel, + num_frames, + height, + width, + patch_size: tuple[int, ...] = (1, 2, 2), + device: torch.device | None = None, + ): + gamma = self.scheduler.config.gamma + _, ph, pw = patch_size + block_size = ph * pw + + cov = ( + torch.eye(block_size, device=device) * (1 + gamma) + - torch.ones(block_size, block_size, device=device) * gamma + ) + cov += torch.eye(block_size, device=device) * 1e-6 + dist = torch.distributions.MultivariateNormal(torch.zeros(block_size, device=device), covariance_matrix=cov) + block_number = batch_size * channel * num_frames * (height // ph) * (width // pw) + + noise = dist.sample((block_number,)) # [block number, block_size] + noise = noise.view(batch_size, channel, num_frames, height // ph, width // pw, ph, pw) + noise = noise.permute(0, 1, 2, 3, 5, 4, 6).reshape(batch_size, channel, num_frames, height, width) + return noise + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int = 384, + width: int = 640, + num_frames: int = 132, + sigmas: list[float] = None, + guidance_scale: float = 5.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + # ------------ I2V ------------ + image: PipelineImageInput | None = None, + image_latents: torch.Tensor | None = None, + fake_image_latents: torch.Tensor | None = None, + add_noise_to_image_latents: bool = True, + image_noise_sigma_min: float = 0.111, + image_noise_sigma_max: float = 0.135, + # ------------ V2V ------------ + video: PipelineImageInput | None = None, + video_latents: torch.Tensor | None = None, + add_noise_to_video_latents: bool = True, + video_noise_sigma_min: float = 0.111, + video_noise_sigma_max: float = 0.135, + # ------------ Stage 1 ------------ + history_sizes: list = [16, 2, 1], + num_latent_frames_per_chunk: int = 9, + keep_first_frame: bool = True, + is_skip_first_chunk: bool = False, + # ------------ Stage 2 ------------ + pyramid_num_inference_steps_list: list = [10, 10, 10], + # ------------ CFG Zero ------------ + use_zero_init: bool | None = True, + zero_steps: int | None = 1, + # ------------ DMD ------------ + is_amplify_first_chunk: bool = False, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). + height (`int`, defaults to `384`): + The height in pixels of the generated image. + width (`int`, defaults to `640`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `132`): + The number of frames in the generated video. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HeliosPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + + Examples: + + Returns: + [`~HeliosPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HeliosPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + history_sizes = sorted(history_sizes, reverse=True) # From big to small + pyramid_num_stages = len(pyramid_num_inference_steps_list) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + image, + video, + guidance_scale, + ) + + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + vae_dtype = self.vae.dtype + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(device, self.vae.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, self.vae.dtype + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Prepare image or video + if image is not None: + image = self.video_processor.preprocess(image, height=height, width=width) + image_latents, fake_image_latents = self.prepare_image_latents( + image, + latents_mean=latents_mean, + latents_std=latents_std, + num_latent_frames_per_chunk=num_latent_frames_per_chunk, + dtype=torch.float32, + device=device, + generator=generator, + latents=image_latents, + fake_latents=fake_image_latents, + ) + + if image_latents is not None and add_noise_to_image_latents: + image_noise_sigma = ( + torch.rand(1, device=device, generator=generator) * (image_noise_sigma_max - image_noise_sigma_min) + + image_noise_sigma_min + ) + image_latents = ( + image_noise_sigma * randn_tensor(image_latents.shape, generator=generator, device=device) + + (1 - image_noise_sigma) * image_latents + ) + fake_image_noise_sigma = ( + torch.rand(1, device=device, generator=generator) * (video_noise_sigma_max - video_noise_sigma_min) + + video_noise_sigma_min + ) + fake_image_latents = ( + fake_image_noise_sigma * randn_tensor(fake_image_latents.shape, generator=generator, device=device) + + (1 - fake_image_noise_sigma) * fake_image_latents + ) + + if video is not None: + video = self.video_processor.preprocess_video(video, height=height, width=width) + image_latents, video_latents = self.prepare_video_latents( + video, + latents_mean=latents_mean, + latents_std=latents_std, + num_latent_frames_per_chunk=num_latent_frames_per_chunk, + dtype=torch.float32, + device=device, + generator=generator, + latents=video_latents, + ) + + if video_latents is not None and add_noise_to_video_latents: + image_noise_sigma = ( + torch.rand(1, device=device, generator=generator) * (image_noise_sigma_max - image_noise_sigma_min) + + image_noise_sigma_min + ) + image_latents = ( + image_noise_sigma * randn_tensor(image_latents.shape, generator=generator, device=device) + + (1 - image_noise_sigma) * image_latents + ) + + noisy_latents_chunks = [] + num_latent_chunks = video_latents.shape[2] // num_latent_frames_per_chunk + for i in range(num_latent_chunks): + chunk_start = i * num_latent_frames_per_chunk + chunk_end = chunk_start + num_latent_frames_per_chunk + latent_chunk = video_latents[:, :, chunk_start:chunk_end, :, :] + + chunk_frames = latent_chunk.shape[2] + frame_sigmas = ( + torch.rand(chunk_frames, device=device, generator=generator) + * (video_noise_sigma_max - video_noise_sigma_min) + + video_noise_sigma_min + ) + frame_sigmas = frame_sigmas.view(1, 1, chunk_frames, 1, 1) + + noisy_chunk = ( + frame_sigmas * randn_tensor(latent_chunk.shape, generator=generator, device=device) + + (1 - frame_sigmas) * latent_chunk + ) + noisy_latents_chunks.append(noisy_chunk) + video_latents = torch.cat(noisy_latents_chunks, dim=2) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + window_num_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1 + num_latent_chunk = max(1, (num_frames + window_num_frames - 1) // window_num_frames) + num_history_latent_frames = sum(history_sizes) + history_video = None + total_generated_latent_frames = 0 + + if not keep_first_frame: + history_sizes[-1] = history_sizes[-1] + 1 + history_latents = torch.zeros( + batch_size, + num_channels_latents, + num_history_latent_frames, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + device=device, + dtype=torch.float32, + ) + if fake_image_latents is not None: + history_latents = torch.cat([history_latents[:, :, :-1, :, :], fake_image_latents], dim=2) + total_generated_latent_frames += 1 + if video_latents is not None: + history_frames = history_latents.shape[2] + video_frames = video_latents.shape[2] + if video_frames < history_frames: + keep_frames = history_frames - video_frames + history_latents = torch.cat([history_latents[:, :, :keep_frames, :, :], video_latents], dim=2) + else: + history_latents = video_latents + total_generated_latent_frames += video_latents.shape[2] + + if keep_first_frame: + indices = torch.arange(0, sum([1, *history_sizes, num_latent_frames_per_chunk])) + ( + indices_prefix, + indices_latents_history_long, + indices_latents_history_mid, + indices_latents_history_1x, + indices_hidden_states, + ) = indices.split([1, *history_sizes, num_latent_frames_per_chunk], dim=0) + indices_latents_history_short = torch.cat([indices_prefix, indices_latents_history_1x], dim=0) + else: + indices = torch.arange(0, sum([*history_sizes, num_latent_frames_per_chunk])) + ( + indices_latents_history_long, + indices_latents_history_mid, + indices_latents_history_short, + indices_hidden_states, + ) = indices.split([*history_sizes, num_latent_frames_per_chunk], dim=0) + indices_hidden_states = indices_hidden_states.unsqueeze(0) + indices_latents_history_short = indices_latents_history_short.unsqueeze(0) + indices_latents_history_mid = indices_latents_history_mid.unsqueeze(0) + indices_latents_history_long = indices_latents_history_long.unsqueeze(0) + + # 6. Denoising loop + for k in range(num_latent_chunk): + is_first_chunk = k == 0 + is_second_chunk = k == 1 + if keep_first_frame: + latents_history_long, latents_history_mid, latents_history_1x = history_latents[ + :, :, -num_history_latent_frames: + ].split(history_sizes, dim=2) + if image_latents is None and is_first_chunk: + latents_prefix = torch.zeros( + ( + batch_size, + num_channels_latents, + 1, + latents_history_1x.shape[-2], + latents_history_1x.shape[-1], + ), + device=device, + dtype=latents_history_1x.dtype, + ) + else: + latents_prefix = image_latents + latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2) + else: + latents_history_long, latents_history_mid, latents_history_short = history_latents[ + :, :, -num_history_latent_frames: + ].split(history_sizes, dim=2) + + latents = self.prepare_latents( + batch_size, + num_channels_latents, + height, + width, + window_num_frames, + dtype=torch.float32, + device=device, + generator=generator, + latents=None, + ) + + num_inference_steps = ( + sum(pyramid_num_inference_steps_list) * 2 + if is_amplify_first_chunk and self.config.is_distilled and is_first_chunk + else sum(pyramid_num_inference_steps_list) + ) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + _, _, _, pyramid_height, pyramid_width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape( + batch_size * num_latent_frames_per_chunk, num_channels_latents, pyramid_height, pyramid_width + ) + for _ in range(pyramid_num_stages - 1): + pyramid_height //= 2 + pyramid_width //= 2 + latents = ( + F.interpolate( + latents, + size=(pyramid_height, pyramid_width), + mode="bilinear", + ) + * 2 + ) + latents = latents.reshape( + batch_size, num_latent_frames_per_chunk, num_channels_latents, pyramid_height, pyramid_width + ).permute(0, 2, 1, 3, 4) + + start_point_list = None + if self.config.is_distilled: + start_point_list = [latents] + + for stage_idx in range(pyramid_num_stages): + patch_size = self.transformer.config.patch_size + image_seq_len = (latents.shape[-1] * latents.shape[-2] * latents.shape[-3]) // ( + patch_size[0] * patch_size[1] * patch_size[2] + ) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.set_timesteps( + pyramid_num_inference_steps_list[stage_idx], + stage_idx, + device=device, + mu=mu, + is_amplify_first_chunk=is_amplify_first_chunk and is_first_chunk, + ) + timesteps = self.scheduler.timesteps + num_warmup_steps = 0 + self._num_timesteps = len(timesteps) + + if stage_idx > 0: + pyramid_height *= 2 + pyramid_width *= 2 + num_frames = latents.shape[2] + latents = latents.permute(0, 2, 1, 3, 4).reshape( + batch_size * num_latent_frames_per_chunk, + num_channels_latents, + pyramid_height // 2, + pyramid_width // 2, + ) + latents = F.interpolate(latents, size=(pyramid_height, pyramid_width), mode="nearest") + latents = latents.reshape( + batch_size, + num_latent_frames_per_chunk, + num_channels_latents, + pyramid_height, + pyramid_width, + ).permute(0, 2, 1, 3, 4) + # Fix the stage + ori_sigma = 1 - self.scheduler.ori_start_sigmas[stage_idx] # the original coeff of signal + gamma = self.scheduler.config.gamma + alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma) + beta = alpha * (1 - ori_sigma) / math.sqrt(gamma) + + batch_size, channel, num_frames, pyramid_height, pyramid_width = latents.shape + noise = self.sample_block_noise( + batch_size, channel, num_frames, pyramid_height, pyramid_width, patch_size, device + ) + noise = noise.to(device=device, dtype=transformer_dtype) + latents = alpha * latents + beta * noise # To fix the block artifact + + if self.config.is_distilled: + start_point_list.append(latents) + + for i, t in enumerate(timesteps): + timestep = t.expand(latents.shape[0]).to(torch.int64) + + latent_model_input = latents.to(transformer_dtype) + latents_history_short = latents_history_short.to(transformer_dtype) + latents_history_mid = latents_history_mid.to(transformer_dtype) + latents_history_long = latents_history_long.to(transformer_dtype) + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + indices_hidden_states=indices_hidden_states, + indices_latents_history_short=indices_latents_history_short, + indices_latents_history_mid=indices_latents_history_mid, + indices_latents_history_long=indices_latents_history_long, + latents_history_short=latents_history_short, + latents_history_mid=latents_history_mid, + latents_history_long=latents_history_long, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + indices_hidden_states=indices_hidden_states, + indices_latents_history_short=indices_latents_history_short, + indices_latents_history_mid=indices_latents_history_mid, + indices_latents_history_long=indices_latents_history_long, + latents_history_short=latents_history_short, + latents_history_mid=latents_history_mid, + latents_history_long=latents_history_long, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.config.is_cfg_zero_star: + noise_pred_text = noise_pred + positive_flat = noise_pred_text.view(batch_size, -1) + negative_flat = noise_uncond.view(batch_size, -1) + + alpha = optimized_scale(positive_flat, negative_flat) + alpha = alpha.view(batch_size, *([1] * (len(noise_pred_text.shape) - 1))) + alpha = alpha.to(noise_pred_text.dtype) + + if (stage_idx == 0 and i <= zero_steps) and use_zero_init: + noise_pred = noise_pred_text * 0.0 + else: + noise_pred = noise_uncond * alpha + guidance_scale * ( + noise_pred_text - noise_uncond * alpha + ) + else: + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + extra_kwargs = ( + { + "cur_sampling_step": i, + "dmd_noisy_tensor": start_point_list[stage_idx] + if start_point_list is not None + else None, + "dmd_sigmas": self.scheduler.sigmas, + "dmd_timesteps": self.scheduler.timesteps, + "all_timesteps": timesteps, + } + if self.config.is_distilled + else {} + ) + + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + return_dict=False, + **extra_kwargs, + )[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) + + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if keep_first_frame and ( + (is_first_chunk and image_latents is None) or (is_skip_first_chunk and is_second_chunk) + ): + image_latents = latents[:, :, 0:1, :, :] + + total_generated_latent_frames += latents.shape[2] + history_latents = torch.cat([history_latents, latents], dim=2) + real_history_latents = history_latents[:, :, -total_generated_latent_frames:] + current_latents = ( + real_history_latents[:, :, -num_latent_frames_per_chunk:].to(vae_dtype) / latents_std + + latents_mean + ) + current_video = self.vae.decode(current_latents, return_dict=False)[0] + + if history_video is None: + history_video = current_video + else: + history_video = torch.cat([history_video, current_video], dim=2) + + self._current_timestep = None + + if output_type != "latent": + generated_frames = history_video.size(2) + generated_frames = ( + generated_frames - 1 + ) // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + history_video = history_video[:, :, :generated_frames] + video = self.video_processor.postprocess_video(history_video, output_type=output_type) + else: + video = real_history_latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HeliosPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/helios/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/helios/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..08546289ef4c0739916c3106b8d9e6a93120d64a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/helios/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class HeliosPipelineOutput(BaseOutput): + r""" + Output class for Helios pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hidream_image/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hidream_image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..498df900e68b70920e5fc49764eb6bdcd2dc3354 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hidream_image/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["HiDreamImagePipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_hidream_image"] = ["HiDreamImagePipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_hidream_image import HiDreamImagePipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hidream_image/pipeline_hidream_image.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hidream_image/pipeline_hidream_image.py new file mode 100644 index 0000000000000000000000000000000000000000..8e5e078cc2af6ce67489f21d483cf9b283c21440 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hidream_image/pipeline_hidream_image.py @@ -0,0 +1,1053 @@ +# Copyright 2025 HiDream-ai Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable + +import torch +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + LlamaForCausalLM, + PreTrainedTokenizerFast, + T5EncoderModel, + T5Tokenizer, +) + +from ...image_processor import VaeImageProcessor +from ...loaders import HiDreamImageLoraLoaderMixin +from ...models import AutoencoderKL, HiDreamImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HiDreamImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from transformers import AutoTokenizer, LlamaForCausalLM + >>> from diffusers import HiDreamImagePipeline + + >>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") + >>> text_encoder_4 = LlamaForCausalLM.from_pretrained( + ... "meta-llama/Meta-Llama-3.1-8B-Instruct", + ... output_hidden_states=True, + ... output_attentions=True, + ... torch_dtype=torch.bfloat16, + ... ) + + >>> pipe = HiDreamImagePipeline.from_pretrained( + ... "HiDream-ai/HiDream-I1-Full", + ... tokenizer_4=tokenizer_4, + ... text_encoder_4=text_encoder_4, + ... torch_dtype=torch.bfloat16, + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> image = pipe( + ... 'A cat holding a sign that says "Hi-Dreams.ai".', + ... height=1024, + ... width=1024, + ... guidance_scale=5.0, + ... num_inference_steps=50, + ... generator=torch.Generator("cuda").manual_seed(0), + ... ).images[0] + >>> image.save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class HiDreamImagePipeline(DiffusionPipeline, HiDreamImageLoraLoaderMixin): + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds_t5", "prompt_embeds_llama3", "pooled_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5Tokenizer, + text_encoder_4: LlamaForCausalLM, + tokenizer_4: PreTrainedTokenizerFast, + transformer: HiDreamImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + text_encoder_4=text_encoder_4, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + tokenizer_4=tokenizer_4, + scheduler=scheduler, + transformer=transformer, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + # HiDreamImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.default_sample_size = 128 + if getattr(self, "tokenizer_4", None) is not None: + self.tokenizer_4.pad_token = self.tokenizer_4.eos_token + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + max_sequence_length: int = 128, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder_3.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=min(max_sequence_length, self.tokenizer_3.model_max_length), + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode( + untruncated_ids[:, min(max_sequence_length, self.tokenizer_3.model_max_length) - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {min(max_sequence_length, self.tokenizer_3.model_max_length)} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + tokenizer, + text_encoder, + prompt: str | list[str], + max_sequence_length: int = 128, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=min(max_sequence_length, 218), + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, 218 - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {218} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + return prompt_embeds + + def _get_llama3_prompt_embeds( + self, + prompt: str | list[str] = None, + max_sequence_length: int = 128, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder_4.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer_4( + prompt, + padding="max_length", + max_length=min(max_sequence_length, self.tokenizer_4.model_max_length), + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer_4(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_4.batch_decode( + untruncated_ids[:, min(max_sequence_length, self.tokenizer_4.model_max_length) - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {min(max_sequence_length, self.tokenizer_4.model_max_length)} tokens: {removed_text}" + ) + + outputs = self.text_encoder_4( + text_input_ids.to(device), + attention_mask=attention_mask.to(device), + output_hidden_states=True, + output_attentions=True, + ) + + prompt_embeds = outputs.hidden_states[1:] + prompt_embeds = torch.stack(prompt_embeds, dim=0) + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str] | None = None, + prompt_2: str | list[str] | None = None, + prompt_3: str | list[str] | None = None, + prompt_4: str | list[str] | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_3: str | list[str] | None = None, + negative_prompt_4: str | list[str] | None = None, + prompt_embeds_t5: list[torch.FloatTensor] | None = None, + prompt_embeds_llama3: list[torch.FloatTensor] | None = None, + negative_prompt_embeds_t5: list[torch.FloatTensor] | None = None, + negative_prompt_embeds_llama3: list[torch.FloatTensor] | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 128, + lora_scale: float | None = None, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = pooled_prompt_embeds.shape[0] + + device = device or self._execution_device + + if pooled_prompt_embeds is None: + pooled_prompt_embeds_1 = self._get_clip_prompt_embeds( + self.tokenizer, self.text_encoder, prompt, max_sequence_length, device, dtype + ) + + if do_classifier_free_guidance and negative_pooled_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if len(negative_prompt) > 1 and len(negative_prompt) != batch_size: + raise ValueError(f"negative_prompt must be of length 1 or {batch_size}") + + negative_pooled_prompt_embeds_1 = self._get_clip_prompt_embeds( + self.tokenizer, self.text_encoder, negative_prompt, max_sequence_length, device, dtype + ) + + if negative_pooled_prompt_embeds_1.shape[0] == 1 and batch_size > 1: + negative_pooled_prompt_embeds_1 = negative_pooled_prompt_embeds_1.repeat(batch_size, 1) + + if pooled_prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + if len(prompt_2) > 1 and len(prompt_2) != batch_size: + raise ValueError(f"prompt_2 must be of length 1 or {batch_size}") + + pooled_prompt_embeds_2 = self._get_clip_prompt_embeds( + self.tokenizer_2, self.text_encoder_2, prompt_2, max_sequence_length, device, dtype + ) + + if pooled_prompt_embeds_2.shape[0] == 1 and batch_size > 1: + pooled_prompt_embeds_2 = pooled_prompt_embeds_2.repeat(batch_size, 1) + + if do_classifier_free_guidance and negative_pooled_prompt_embeds is None: + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_2 = [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + + if len(negative_prompt_2) > 1 and len(negative_prompt_2) != batch_size: + raise ValueError(f"negative_prompt_2 must be of length 1 or {batch_size}") + + negative_pooled_prompt_embeds_2 = self._get_clip_prompt_embeds( + self.tokenizer_2, self.text_encoder_2, negative_prompt_2, max_sequence_length, device, dtype + ) + + if negative_pooled_prompt_embeds_2.shape[0] == 1 and batch_size > 1: + negative_pooled_prompt_embeds_2 = negative_pooled_prompt_embeds_2.repeat(batch_size, 1) + + if pooled_prompt_embeds is None: + pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1) + + if do_classifier_free_guidance and negative_pooled_prompt_embeds is None: + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embeds_1, negative_pooled_prompt_embeds_2], dim=-1 + ) + + if prompt_embeds_t5 is None: + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + if len(prompt_3) > 1 and len(prompt_3) != batch_size: + raise ValueError(f"prompt_3 must be of length 1 or {batch_size}") + + prompt_embeds_t5 = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, device, dtype) + + if prompt_embeds_t5.shape[0] == 1 and batch_size > 1: + prompt_embeds_t5 = prompt_embeds_t5.repeat(batch_size, 1, 1) + + if do_classifier_free_guidance and negative_prompt_embeds_t5 is None: + negative_prompt_3 = negative_prompt_3 or negative_prompt + negative_prompt_3 = [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + + if len(negative_prompt_3) > 1 and len(negative_prompt_3) != batch_size: + raise ValueError(f"negative_prompt_3 must be of length 1 or {batch_size}") + + negative_prompt_embeds_t5 = self._get_t5_prompt_embeds( + negative_prompt_3, max_sequence_length, device, dtype + ) + + if negative_prompt_embeds_t5.shape[0] == 1 and batch_size > 1: + negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(batch_size, 1, 1) + + if prompt_embeds_llama3 is None: + prompt_4 = prompt_4 or prompt + prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4 + + if len(prompt_4) > 1 and len(prompt_4) != batch_size: + raise ValueError(f"prompt_4 must be of length 1 or {batch_size}") + + prompt_embeds_llama3 = self._get_llama3_prompt_embeds(prompt_4, max_sequence_length, device, dtype) + + if prompt_embeds_llama3.shape[0] == 1 and batch_size > 1: + prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, batch_size, 1, 1) + + if do_classifier_free_guidance and negative_prompt_embeds_llama3 is None: + negative_prompt_4 = negative_prompt_4 or negative_prompt + negative_prompt_4 = [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4 + + if len(negative_prompt_4) > 1 and len(negative_prompt_4) != batch_size: + raise ValueError(f"negative_prompt_4 must be of length 1 or {batch_size}") + + negative_prompt_embeds_llama3 = self._get_llama3_prompt_embeds( + negative_prompt_4, max_sequence_length, device, dtype + ) + + if negative_prompt_embeds_llama3.shape[0] == 1 and batch_size > 1: + negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, batch_size, 1, 1) + + # duplicate pooled_prompt_embeds for each generation per prompt + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + # duplicate t5_prompt_embeds for batch_size and num_images_per_prompt + bs_embed, seq_len, _ = prompt_embeds_t5.shape + if bs_embed == 1 and batch_size > 1: + prompt_embeds_t5 = prompt_embeds_t5.repeat(batch_size, 1, 1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate prompt_embeds_t5 of batch size {bs_embed}") + prompt_embeds_t5 = prompt_embeds_t5.repeat(1, num_images_per_prompt, 1) + prompt_embeds_t5 = prompt_embeds_t5.view(batch_size * num_images_per_prompt, seq_len, -1) + + # duplicate llama3_prompt_embeds for batch_size and num_images_per_prompt + _, bs_embed, seq_len, dim = prompt_embeds_llama3.shape + if bs_embed == 1 and batch_size > 1: + prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, batch_size, 1, 1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate prompt_embeds_llama3 of batch size {bs_embed}") + prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, 1, num_images_per_prompt, 1) + prompt_embeds_llama3 = prompt_embeds_llama3.view(-1, batch_size * num_images_per_prompt, seq_len, dim) + + if do_classifier_free_guidance: + # duplicate negative_pooled_prompt_embeds for batch_size and num_images_per_prompt + bs_embed, seq_len = negative_pooled_prompt_embeds.shape + if bs_embed == 1 and batch_size > 1: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(batch_size, 1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate negative_pooled_prompt_embeds of batch size {bs_embed}") + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + # duplicate negative_t5_prompt_embeds for batch_size and num_images_per_prompt + bs_embed, seq_len, _ = negative_prompt_embeds_t5.shape + if bs_embed == 1 and batch_size > 1: + negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(batch_size, 1, 1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate negative_prompt_embeds_t5 of batch size {bs_embed}") + negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds_t5 = negative_prompt_embeds_t5.view(batch_size * num_images_per_prompt, seq_len, -1) + + # duplicate negative_prompt_embeds_llama3 for batch_size and num_images_per_prompt + _, bs_embed, seq_len, dim = negative_prompt_embeds_llama3.shape + if bs_embed == 1 and batch_size > 1: + negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, batch_size, 1, 1) + elif bs_embed > 1 and bs_embed != batch_size: + raise ValueError(f"cannot duplicate negative_prompt_embeds_llama3 of batch size {bs_embed}") + negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, 1, num_images_per_prompt, 1) + negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.view( + -1, batch_size * num_images_per_prompt, seq_len, dim + ) + + return ( + prompt_embeds_t5, + negative_prompt_embeds_t5, + prompt_embeds_llama3, + negative_prompt_embeds_llama3, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + prompt_4, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + negative_prompt_4=None, + prompt_embeds_t5=None, + prompt_embeds_llama3=None, + negative_prompt_embeds_t5=None, + negative_prompt_embeds_llama3=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and pooled_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `pooled_prompt_embeds`: {pooled_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and pooled_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `pooled_prompt_embeds`: {pooled_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_3 is not None and prompt_embeds_t5 is not None: + raise ValueError( + f"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds_t5`: {prompt_embeds_t5}. Please make sure to" + " only forward one of the two." + ) + elif prompt_4 is not None and prompt_embeds_llama3 is not None: + raise ValueError( + f"Cannot forward both `prompt_4`: {prompt_4} and `prompt_embeds_llama3`: {prompt_embeds_llama3}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and pooled_prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `pooled_prompt_embeds`. Cannot leave both `prompt` and `pooled_prompt_embeds` undefined." + ) + elif prompt is None and prompt_embeds_t5 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_t5`. Cannot leave both `prompt` and `prompt_embeds_t5` undefined." + ) + elif prompt is None and prompt_embeds_llama3 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_llama3`. Cannot leave both `prompt` and `prompt_embeds_llama3` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): + raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") + elif prompt_4 is not None and (not isinstance(prompt_4, str) and not isinstance(prompt_4, list)): + raise ValueError(f"`prompt_4` has to be of type `str` or `list` but is {type(prompt_4)}") + + if negative_prompt is not None and negative_pooled_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_pooled_prompt_embeds`:" + f" {negative_pooled_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_pooled_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_pooled_prompt_embeds`:" + f" {negative_pooled_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_3 is not None and negative_prompt_embeds_t5 is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds_t5`:" + f" {negative_prompt_embeds_t5}. Please make sure to only forward one of the two." + ) + elif negative_prompt_4 is not None and negative_prompt_embeds_llama3 is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_4`: {negative_prompt_4} and `negative_prompt_embeds_llama3`:" + f" {negative_prompt_embeds_llama3}. Please make sure to only forward one of the two." + ) + + if pooled_prompt_embeds is not None and negative_pooled_prompt_embeds is not None: + if pooled_prompt_embeds.shape != negative_pooled_prompt_embeds.shape: + raise ValueError( + "`pooled_prompt_embeds` and `negative_pooled_prompt_embeds` must have the same shape when passed directly, but" + f" got: `pooled_prompt_embeds` {pooled_prompt_embeds.shape} != `negative_pooled_prompt_embeds`" + f" {negative_pooled_prompt_embeds.shape}." + ) + if prompt_embeds_t5 is not None and negative_prompt_embeds_t5 is not None: + if prompt_embeds_t5.shape != negative_prompt_embeds_t5.shape: + raise ValueError( + "`prompt_embeds_t5` and `negative_prompt_embeds_t5` must have the same shape when passed directly, but" + f" got: `prompt_embeds_t5` {prompt_embeds_t5.shape} != `negative_prompt_embeds_t5`" + f" {negative_prompt_embeds_t5.shape}." + ) + if prompt_embeds_llama3 is not None and negative_prompt_embeds_llama3 is not None: + if prompt_embeds_llama3.shape != negative_prompt_embeds_llama3.shape: + raise ValueError( + "`prompt_embeds_llama3` and `negative_prompt_embeds_llama3` must have the same shape when passed directly, but" + f" got: `prompt_embeds_llama3` {prompt_embeds_llama3.shape} != `negative_prompt_embeds_llama3`" + f" {negative_prompt_embeds_llama3.shape}." + ) + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + prompt_3: str | list[str] | None = None, + prompt_4: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_3: str | list[str] | None = None, + negative_prompt_4: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds_t5: torch.FloatTensor | None = None, + prompt_embeds_llama3: torch.FloatTensor | None = None, + negative_prompt_embeds_t5: torch.FloatTensor | None = None, + negative_prompt_embeds_llama3: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 128, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + will be used instead. + prompt_4 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_4` and `text_encoder_4`. If not defined, `prompt` is + will be used instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_4 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_4` and + `text_encoder_4`. If not defined, `negative_prompt` is used in all the text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 128): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.hidream_image.HiDreamImagePipelineOutput`] or `tuple`: + [`~pipelines.hidream_image.HiDreamImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated. images. + """ + + prompt_embeds = kwargs.get("prompt_embeds", None) + negative_prompt_embeds = kwargs.get("negative_prompt_embeds", None) + + if prompt_embeds is not None: + deprecation_message = "The `prompt_embeds` argument is deprecated. Please use `prompt_embeds_t5` and `prompt_embeds_llama3` instead." + deprecate("prompt_embeds", "0.35.0", deprecation_message) + prompt_embeds_t5 = prompt_embeds[0] + prompt_embeds_llama3 = prompt_embeds[1] + + if negative_prompt_embeds is not None: + deprecation_message = "The `negative_prompt_embeds` argument is deprecated. Please use `negative_prompt_embeds_t5` and `negative_prompt_embeds_llama3` instead." + deprecate("negative_prompt_embeds", "0.35.0", deprecation_message) + negative_prompt_embeds_t5 = negative_prompt_embeds[0] + negative_prompt_embeds_llama3 = negative_prompt_embeds[1] + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + division = self.vae_scale_factor * 2 + S_max = (self.default_sample_size * self.vae_scale_factor) ** 2 + scale = S_max / (width * height) + scale = math.sqrt(scale) + width, height = int(width * scale // division * division), int(height * scale // division * division) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + prompt_4, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + negative_prompt_4=negative_prompt_4, + prompt_embeds_t5=prompt_embeds_t5, + prompt_embeds_llama3=prompt_embeds_llama3, + negative_prompt_embeds_t5=negative_prompt_embeds_t5, + negative_prompt_embeds_llama3=negative_prompt_embeds_llama3, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + elif pooled_prompt_embeds is not None: + batch_size = pooled_prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode prompt + lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None + ( + prompt_embeds_t5, + negative_prompt_embeds_t5, + prompt_embeds_llama3, + negative_prompt_embeds_llama3, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + prompt_4=prompt_4, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + negative_prompt_4=negative_prompt_4, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds_t5=prompt_embeds_t5, + prompt_embeds_llama3=prompt_embeds_llama3, + negative_prompt_embeds_t5=negative_prompt_embeds_t5, + negative_prompt_embeds_llama3=negative_prompt_embeds_llama3, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + if self.do_classifier_free_guidance: + prompt_embeds_t5 = torch.cat([negative_prompt_embeds_t5, prompt_embeds_t5], dim=0) + prompt_embeds_llama3 = torch.cat([negative_prompt_embeds_llama3, prompt_embeds_llama3], dim=1) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + pooled_prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + mu = calculate_shift(self.transformer.max_seq) + scheduler_kwargs = {"mu": mu} + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + if isinstance(self.scheduler, UniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=timestep_device) # , shift=math.exp(mu)) + timesteps = self.scheduler.timesteps + else: + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + timestep_device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timesteps=timestep, + encoder_hidden_states_t5=prompt_embeds_t5, + encoder_hidden_states_llama3=prompt_embeds_llama3, + pooled_embeds=pooled_prompt_embeds, + return_dict=False, + )[0] + noise_pred = -noise_pred + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds_t5 = callback_outputs.pop("prompt_embeds_t5", prompt_embeds_t5) + prompt_embeds_llama3 = callback_outputs.pop("prompt_embeds_llama3", prompt_embeds_llama3) + pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return HiDreamImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hidream_image/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hidream_image/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..1802c7220691d5b14c99580de3a6bae5846b838d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hidream_image/pipeline_output.py @@ -0,0 +1,34 @@ +# Copyright 2025 HiDream-ai Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class HiDreamImagePipelineOutput(BaseOutput): + """ + Output class for HiDreamImage pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image] | np.ndarray diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_image/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7da72fa12b2c18c1b9ec7f9fb9b2170e4b516041 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_image/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_hunyuanimage"] = ["HunyuanImagePipeline"] + _import_structure["pipeline_hunyuanimage_refiner"] = ["HunyuanImageRefinerPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_hunyuanimage import HunyuanImagePipeline + from .pipeline_hunyuanimage_refiner import HunyuanImageRefinerPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage.py new file mode 100644 index 0000000000000000000000000000000000000000..50239e9afa22678bfe1532291fa75cda6e756988 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage.py @@ -0,0 +1,868 @@ +# Copyright 2025 Hunyuan-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +import re +from typing import Any, Callable + +import numpy as np +import torch +from transformers import ByT5Tokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, T5EncoderModel + +from ...guiders import AdaptiveProjectedMixGuidance +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKLHunyuanImage, HunyuanImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HunyuanImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import HunyuanImagePipeline + + >>> pipe = HunyuanImagePipeline.from_pretrained( + ... "hunyuanvideo-community/HunyuanImage-2.1-Diffusers", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, negative_prompt="", num_inference_steps=50).images[0] + >>> image.save("hunyuanimage.png") + ``` +""" + + +def extract_glyph_text(prompt: str): + """ + Extract text enclosed in quotes for glyph rendering. + + Finds text in single quotes, double quotes, and Chinese quotes, then formats it for byT5 processing. + + Args: + prompt: Input text prompt + + Returns: + Formatted glyph text string or None if no quoted text found + """ + text_prompt_texts = [] + pattern_quote_single = r"\'(.*?)\'" + pattern_quote_double = r"\"(.*?)\"" + pattern_quote_chinese_single = r"‘(.*?)’" + pattern_quote_chinese_double = r"“(.*?)”" + + matches_quote_single = re.findall(pattern_quote_single, prompt) + matches_quote_double = re.findall(pattern_quote_double, prompt) + matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, prompt) + matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, prompt) + + text_prompt_texts.extend(matches_quote_single) + text_prompt_texts.extend(matches_quote_double) + text_prompt_texts.extend(matches_quote_chinese_single) + text_prompt_texts.extend(matches_quote_chinese_double) + + if text_prompt_texts: + glyph_text_formatted = ". ".join([f'Text "{text}"' for text in text_prompt_texts]) + ". " + else: + glyph_text_formatted = None + + return glyph_text_formatted + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class HunyuanImagePipeline(DiffusionPipeline): + r""" + The HunyuanImage pipeline for text-to-image generation. + + Args: + transformer ([`HunyuanImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLHunyuanImage`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer]. + text_encoder_2 ([`T5EncoderModel`]): + [T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel) + variant. + tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer] + guider ([`AdaptiveProjectedMixGuidance`]): + [AdaptiveProjectedMixGuidance]to be used to guide the image generation. + ocr_guider ([`AdaptiveProjectedMixGuidance`], *optional*): + [AdaptiveProjectedMixGuidance] to be used to guide the image generation when text rendering is needed. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + _optional_components = ["ocr_guider", "guider"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLHunyuanImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: ByT5Tokenizer, + transformer: HunyuanImageTransformer2DModel, + guider: AdaptiveProjectedMixGuidance | None = None, + ocr_guider: AdaptiveProjectedMixGuidance | None = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + guider=guider, + ocr_guider=ocr_guider, + ) + + self.vae_scale_factor = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 32 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = 1000 + self.tokenizer_2_max_length = 128 + self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>" + self.prompt_template_encode_start_idx = 34 + self.default_sample_size = 64 + + def _get_qwen_prompt_embeds( + self, + tokenizer: Qwen2Tokenizer, + text_encoder: Qwen2_5_VLForConditionalGeneration, + prompt: str | list[str] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + tokenizer_max_length: int = 1000, + template: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>", + drop_idx: int = 34, + hidden_state_skip_layer: int = 2, + ): + device = device or self._execution_device + dtype = dtype or text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + txt = [template.format(e) for e in prompt] + txt_tokens = tokenizer( + txt, max_length=tokenizer_max_length + drop_idx, padding="max_length", truncation=True, return_tensors="pt" + ).to(device) + + encoder_hidden_states = text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + prompt_embeds = encoder_hidden_states.hidden_states[-(hidden_state_skip_layer + 1)] + + prompt_embeds = prompt_embeds[:, drop_idx:] + encoder_attention_mask = txt_tokens.attention_mask[:, drop_idx:] + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + encoder_attention_mask = encoder_attention_mask.to(device=device) + + return prompt_embeds, encoder_attention_mask + + def _get_byt5_prompt_embeds( + self, + tokenizer: ByT5Tokenizer, + text_encoder: T5EncoderModel, + prompt: str, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + tokenizer_max_length: int = 128, + ): + device = device or self._execution_device + dtype = dtype or text_encoder.dtype + + if isinstance(prompt, list): + raise ValueError("byt5 prompt should be a string") + elif prompt is None: + raise ValueError("byt5 prompt should not be None") + + txt_tokens = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer_max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ).to(device) + + prompt_embeds = text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask.float(), + )[0] + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + encoder_attention_mask = txt_tokens.attention_mask.to(device=device) + + return prompt_embeds, encoder_attention_mask + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + batch_size: int = 1, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + prompt_embeds_2: torch.Tensor | None = None, + prompt_embeds_mask_2: torch.Tensor | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + batch_size (`int`): + batch size of prompts, defaults to 1 + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input + argument. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input + argument using self.tokenizer_2 and self.text_encoder_2. + prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input + argument using self.tokenizer_2 and self.text_encoder_2. + """ + device = device or self._execution_device + + if prompt is None: + prompt = [""] * batch_size + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + prompt=prompt, + device=device, + tokenizer_max_length=self.tokenizer_max_length, + template=self.prompt_template_encode, + drop_idx=self.prompt_template_encode_start_idx, + ) + + if prompt_embeds_2 is None: + prompt_embeds_2_list = [] + prompt_embeds_mask_2_list = [] + + glyph_texts = [extract_glyph_text(p) for p in prompt] + for glyph_text in glyph_texts: + if glyph_text is None: + glyph_text_embeds = torch.zeros( + (1, self.tokenizer_2_max_length, self.text_encoder_2.config.d_model), device=device + ) + glyph_text_embeds_mask = torch.zeros( + (1, self.tokenizer_2_max_length), device=device, dtype=torch.int64 + ) + else: + glyph_text_embeds, glyph_text_embeds_mask = self._get_byt5_prompt_embeds( + tokenizer=self.tokenizer_2, + text_encoder=self.text_encoder_2, + prompt=glyph_text, + device=device, + tokenizer_max_length=self.tokenizer_2_max_length, + ) + + prompt_embeds_2_list.append(glyph_text_embeds) + prompt_embeds_mask_2_list.append(glyph_text_embeds_mask) + + prompt_embeds_2 = torch.cat(prompt_embeds_2_list, dim=0) + prompt_embeds_mask_2 = torch.cat(prompt_embeds_mask_2_list, dim=0) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + _, seq_len_2, _ = prompt_embeds_2.shape + prompt_embeds_2 = prompt_embeds_2.repeat(1, num_images_per_prompt, 1) + prompt_embeds_2 = prompt_embeds_2.view(batch_size * num_images_per_prompt, seq_len_2, -1) + prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask_2 = prompt_embeds_mask_2.view(batch_size * num_images_per_prompt, seq_len_2) + + return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + prompt_embeds_2=None, + prompt_embeds_mask_2=None, + negative_prompt_embeds_2=None, + negative_prompt_embeds_mask_2=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + + if prompt_embeds_2 is not None and prompt_embeds_mask_2 is None: + raise ValueError( + "If `prompt_embeds_2` are provided, `prompt_embeds_mask_2` also have to be passed. Make sure to generate `prompt_embeds_mask_2` from the same text encoder that was used to generate `prompt_embeds_2`." + ) + if negative_prompt_embeds_2 is not None and negative_prompt_embeds_mask_2 is None: + raise ValueError( + "If `negative_prompt_embeds_2` are provided, `negative_prompt_embeds_mask_2` also have to be passed. Make sure to generate `negative_prompt_embeds_mask_2` from the same text encoder that was used to generate `negative_prompt_embeds_2`." + ) + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + return latents.to(device=device, dtype=dtype) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + distilled_guidance_scale: float | None = 3.25, + sigmas: list[float] | None = None, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + prompt_embeds_2: torch.Tensor | None = None, + prompt_embeds_mask_2: torch.Tensor | None = None, + negative_prompt_embeds_2: torch.Tensor | None = None, + negative_prompt_embeds_mask_2: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined and negative_prompt_embeds is + not provided, will use an empty negative prompt. Ignored when not using guidance. ). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + distilled_guidance_scale (`float`, *optional*, defaults to None): + A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance + where the guidance scale is applied during inference through noise prediction rescaling, guidance + distilled models take the guidance scale directly as an input parameter during forward pass. Guidance + is enabled by setting `distilled_guidance_scale > 1`. Higher guidance scale encourages to generate + images that are closely linked to the text `prompt`, usually at the expense of lower image quality. For + guidance distilled models, this parameter is required. For non-distilled models, this parameter will be + ignored. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated text embeddings mask. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, text embeddings mask will be generated from `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated text embeddings for ocr. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, text embeddings for ocr will be generated from `prompt` input argument. + prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated text embeddings mask for ocr. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, text embeddings mask for ocr will be generated from `prompt` input + argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings mask. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative text embeddings mask will be generated from `negative_prompt` + input argument. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings for ocr. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative text embeddings for ocr will be generated from `negative_prompt` + input argument. + negative_prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings mask for ocr. Can be used to easily tweak text inputs, *e.g.* + prompt weighting. If not provided, negative text embeddings mask for ocr will be generated from + `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] or `tuple`: + [`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds_2=prompt_embeds_2, + prompt_embeds_mask_2=prompt_embeds_mask_2, + negative_prompt_embeds_2=negative_prompt_embeds_2, + negative_prompt_embeds_mask_2=negative_prompt_embeds_mask_2, + ) + + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. prepare prompt embeds + + prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds_2=prompt_embeds_2, + prompt_embeds_mask_2=prompt_embeds_mask_2, + ) + + prompt_embeds = prompt_embeds.to(self.transformer.dtype) + prompt_embeds_2 = prompt_embeds_2.to(self.transformer.dtype) + + # select guider + if not torch.all(prompt_embeds_2 == 0) and self.ocr_guider is not None: + # prompt contains ocr and pipeline has a guider for ocr + guider = self.ocr_guider + elif self.guider is not None: + guider = self.guider + # distilled model does not use guidance method, use default guider with enabled=False + else: + guider = AdaptiveProjectedMixGuidance(enabled=False) + + if guider._enabled and guider.num_conditions > 1: + ( + negative_prompt_embeds, + negative_prompt_embeds_mask, + negative_prompt_embeds_2, + negative_prompt_embeds_mask_2, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds_2=negative_prompt_embeds_2, + prompt_embeds_mask_2=negative_prompt_embeds_mask_2, + ) + + negative_prompt_embeds = negative_prompt_embeds.to(self.transformer.dtype) + negative_prompt_embeds_2 = negative_prompt_embeds_2.to(self.transformer.dtype) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance (for guidance-distilled model) + if self.transformer.config.guidance_embeds and distilled_guidance_scale is None: + raise ValueError("`distilled_guidance_scale` is required for guidance-distilled model.") + + if self.transformer.config.guidance_embeds: + guidance = ( + torch.tensor( + [distilled_guidance_scale] * latents.shape[0], dtype=self.transformer.dtype, device=device + ) + * 1000.0 + ) + + else: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + if self.transformer.config.use_meanflow: + if i == len(timesteps) - 1: + timestep_r = torch.tensor([0.0], device=device) + else: + timestep_r = timesteps[i + 1] + timestep_r = timestep_r.expand(latents.shape[0]).to(latents.dtype) + else: + timestep_r = None + + # Step 1: Collect model inputs needed for the guidance method + # conditional inputs should always be first element in the tuple + guider_inputs = { + "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds), + "encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask), + "encoder_hidden_states_2": (prompt_embeds_2, negative_prompt_embeds_2), + "encoder_attention_mask_2": (prompt_embeds_mask_2, negative_prompt_embeds_mask_2), + } + + # Step 2: Update guider's internal state for this denoising step + guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t) + + # Step 3: Prepare batched model inputs based on the guidance method + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). + guider_state = guider.prepare_inputs(guider_inputs) + # Step 4: Run the denoiser for each batch + # Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.). + # We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred. + for guider_state_batch in guider_state: + guider.prepare_models(self.transformer) + + # Extract conditioning kwargs for this batch (e.g., encoder_hidden_states) + cond_kwargs = { + input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys() + } + + # e.g. "pred_cond"/"pred_uncond" + context_name = getattr(guider_state_batch, guider._identifier_key) + with self.transformer.cache_context(context_name): + # Run denoiser and store noise prediction in this batch + guider_state_batch.noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep, + timestep_r=timestep_r, + guidance=guidance, + attention_kwargs=self.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + + # Cleanup model (e.g., remove hooks) + guider.cleanup_models(self.transformer) + + # Step 5: Combine predictions using the guidance method + # The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm. + # Continuing the CFG example, the guider receives: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0 + # {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1 + # ] + # And extracts predictions using the __guidance_identifier__: + # pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond + # pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond + # Then applies CFG formula: + # noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + # Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond) + noise_pred = guider(guider_state)[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return HunyuanImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage_refiner.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage_refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..93e4deb2974a9f9854c49cd45bcaec1b7e7a4257 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage_refiner.py @@ -0,0 +1,754 @@ +# Copyright 2025 Hunyuan-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer + +from ...guiders import AdaptiveProjectedMixGuidance +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import AutoencoderKLHunyuanImageRefiner, HunyuanImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HunyuanImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import HunyuanImageRefinerPipeline + + >>> pipe = HunyuanImageRefinerPipeline.from_pretrained( + ... "hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> image = load_image("path/to/image.png") + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, image=image, num_inference_steps=4).images[0] + >>> image.save("hunyuanimage.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class HunyuanImageRefinerPipeline(DiffusionPipeline): + r""" + The HunyuanImage pipeline for text-to-image generation. + + Args: + transformer ([`HunyuanImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLHunyuanImageRefiner`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer]. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + _optional_components = ["guider"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLHunyuanImageRefiner, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + transformer: HunyuanImageTransformer2DModel, + guider: AdaptiveProjectedMixGuidance | None = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + guider=guider, + ) + + self.vae_scale_factor = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 16 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = 256 + self.prompt_template_encode = "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + self.prompt_template_encode_start_idx = 36 + self.default_sample_size = 64 + self.latent_channels = self.transformer.config.in_channels // 2 if getattr(self, "transformer", None) else 64 + + # Copied from diffusers.pipelines.hunyuan_image.pipeline_hunyuanimage.HunyuanImagePipeline._get_qwen_prompt_embeds + def _get_qwen_prompt_embeds( + self, + tokenizer: Qwen2Tokenizer, + text_encoder: Qwen2_5_VLForConditionalGeneration, + prompt: str | list[str] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + tokenizer_max_length: int = 1000, + template: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>", + drop_idx: int = 34, + hidden_state_skip_layer: int = 2, + ): + device = device or self._execution_device + dtype = dtype or text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + txt = [template.format(e) for e in prompt] + txt_tokens = tokenizer( + txt, max_length=tokenizer_max_length + drop_idx, padding="max_length", truncation=True, return_tensors="pt" + ).to(device) + + encoder_hidden_states = text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + prompt_embeds = encoder_hidden_states.hidden_states[-(hidden_state_skip_layer + 1)] + + prompt_embeds = prompt_embeds[:, drop_idx:] + encoder_attention_mask = txt_tokens.attention_mask[:, drop_idx:] + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + encoder_attention_mask = encoder_attention_mask.to(device=device) + + return prompt_embeds, encoder_attention_mask + + def encode_prompt( + self, + prompt: str | list[str] | None = None, + device: torch.device | None = None, + batch_size: int = 1, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + batch_size (`int`): + batch size of prompts, defaults to 1 + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input + argument. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input + argument using self.tokenizer_2 and self.text_encoder_2. + prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input + argument using self.tokenizer_2 and self.text_encoder_2. + """ + device = device or self._execution_device + + if prompt is None: + prompt = [""] * batch_size + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + prompt=prompt, + device=device, + tokenizer_max_length=self.tokenizer_max_length, + template=self.prompt_template_encode, + drop_idx=self.prompt_template_encode_start_idx, + ) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + def prepare_latents( + self, + image_latents, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + strength=0.25, + ): + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, 1, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + cond_latents = strength * noise + (1 - strength) * image_latents + + return latents, cond_latents + + @staticmethod + def _reorder_image_tokens(image_latents): + image_latents = torch.cat((image_latents[:, :, :1], image_latents), dim=2) + batch_size, num_latent_channels, num_latent_frames, latent_height, latent_width = image_latents.shape + image_latents = image_latents.permute(0, 2, 1, 3, 4) + image_latents = image_latents.reshape( + batch_size, num_latent_frames // 2, num_latent_channels * 2, latent_height, latent_width + ) + image_latents = image_latents.permute(0, 2, 1, 3, 4).contiguous() + + return image_latents + + @staticmethod + def _restore_image_tokens_order(latents): + """Restore image tokens order by splitting channels and removing first frame slice.""" + batch_size, num_latent_channels, num_latent_frames, latent_height, latent_width = latents.shape + + latents = latents.permute(0, 2, 1, 3, 4) # B, F, C, H, W + latents = latents.reshape( + batch_size, num_latent_frames * 2, num_latent_channels // 2, latent_height, latent_width + ) # B, F*2, C//2, H, W + + latents = latents.permute(0, 2, 1, 3, 4) # B, C//2, F*2, H, W + # Remove first frame slice + latents = latents[:, :, 1:] + + return latents + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="sample") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="sample") + image_latents = self._reorder_image_tokens(image_latents) + + image_latents = image_latents * self.vae.config.scaling_factor + + return image_latents + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + distilled_guidance_scale: float | None = 3.25, + image: PipelineImageInput | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 4, + sigmas: list[float] | None = None, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, will use an empty negative + prompt. Ignored when not using guidance. + distilled_guidance_scale (`float`, *optional*, defaults to None): + A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance + where the guidance scale is applied during inference through noise prediction rescaling, guidance + distilled models take the guidance scale directly as an input parameter during forward pass. Guidance + is enabled by setting `distilled_guidance_scale > 1`. Higher guidance scale encourages to generate + images that are closely linked to the text `prompt`, usually at the expense of lower image quality. For + guidance distilled models, this parameter is required. For non-distilled models, this parameter will be + ignored. + num_images_per_prompt (`int`, *optional*, defaults to 1): + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] or `tuple`: + [`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. process image + if image is not None and isinstance(image, torch.Tensor) and image.shape[1] == self.latent_channels: + image_latents = image + else: + image = self.image_processor.preprocess(image, height, width) + image = image.unsqueeze(2).to(device, dtype=self.vae.dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + + # 3.prepare prompt embeds + + if self.guider is not None: + guider = self.guider + else: + # distilled model does not use guidance method, use default guider with enabled=False + guider = AdaptiveProjectedMixGuidance(enabled=False) + + requires_unconditional_embeds = guider._enabled and guider.num_conditions > 1 + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + ) + + prompt_embeds = prompt_embeds.to(self.transformer.dtype) + + if requires_unconditional_embeds: + ( + negative_prompt_embeds, + negative_prompt_embeds_mask, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + ) + + negative_prompt_embeds = negative_prompt_embeds.to(self.transformer.dtype) + + # 4. Prepare latent variables + latents, cond_latents = self.prepare_latents( + image_latents=image_latents, + batch_size=batch_size * num_images_per_prompt, + num_channels_latents=self.latent_channels, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance (this pipeline only supports guidance-distilled models) + if distilled_guidance_scale is None: + raise ValueError("`distilled_guidance_scale` is required for guidance-distilled model.") + guidance = ( + torch.tensor([distilled_guidance_scale] * latents.shape[0], dtype=self.transformer.dtype, device=device) + * 1000.0 + ) + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + latent_model_input = torch.cat([latents, cond_latents], dim=1).to(self.transformer.dtype) + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + # Step 1: Collect model inputs needed for the guidance method + # conditional inputs should always be first element in the tuple + guider_inputs = { + "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds), + "encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask), + } + + # Step 2: Update guider's internal state for this denoising step + guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t) + + # Step 3: Prepare batched model inputs based on the guidance method + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). + guider_state = guider.prepare_inputs(guider_inputs) + + # Step 4: Run the denoiser for each batch + # Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.). + # We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred. + for guider_state_batch in guider_state: + guider.prepare_models(self.transformer) + + # Extract conditioning kwargs for this batch (e.g., encoder_hidden_states) + cond_kwargs = { + input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys() + } + + # e.g. "pred_cond"/"pred_uncond" + context_name = getattr(guider_state_batch, guider._identifier_key) + with self.transformer.cache_context(context_name): + # Run denoiser and store noise prediction in this batch + guider_state_batch.noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + guidance=guidance, + attention_kwargs=self.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + + # Cleanup model (e.g., remove hooks) + guider.cleanup_models(self.transformer) + + # Step 5: Combine predictions using the guidance method + # The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm. + # Continuing the CFG example, the guider receives: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0 + # {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1 + # ] + # And extracts predictions using the __guidance_identifier__: + # pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond + # pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond + # Then applies CFG formula: + # noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + # Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond) + noise_pred = guider(guider_state)[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + latents = self._restore_image_tokens_order(latents) + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image.squeeze(2), output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return HunyuanImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_image/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_image/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..169436b7d86f78d4584396a42a2d3f1d0209d3a3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_image/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class HunyuanImagePipelineOutput(BaseOutput): + """ + Output class for HunyuanImage pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image, np.ndarray] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d42d38fac9793db2c2fb94c4b41d3755be75d898 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video/__init__.py @@ -0,0 +1,54 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_hunyuan_skyreels_image2video"] = ["HunyuanSkyreelsImageToVideoPipeline"] + _import_structure["pipeline_hunyuan_video"] = ["HunyuanVideoPipeline"] + _import_structure["pipeline_hunyuan_video_framepack"] = ["HunyuanVideoFramepackPipeline"] + _import_structure["pipeline_hunyuan_video_image2video"] = ["HunyuanVideoImageToVideoPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_hunyuan_skyreels_image2video import HunyuanSkyreelsImageToVideoPipeline + from .pipeline_hunyuan_video import HunyuanVideoPipeline + from .pipeline_hunyuan_video_framepack import HunyuanVideoFramepackPipeline + from .pipeline_hunyuan_video_image2video import HunyuanVideoImageToVideoPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py new file mode 100644 index 0000000000000000000000000000000000000000..1a7cae256d638ae2a02d23df33e9321319ab36c2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py @@ -0,0 +1,832 @@ +# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import HunyuanVideoLoraLoaderMixin +from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HunyuanVideoPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import HunyuanSkyreelsImageToVideoPipeline, HunyuanVideoTransformer3DModel + >>> from diffusers.utils import load_image, export_to_video + + >>> model_id = "hunyuanvideo-community/HunyuanVideo" + >>> transformer_model_id = "Skywork/SkyReels-V1-Hunyuan-I2V" + >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( + ... transformer_model_id, torch_dtype=torch.bfloat16 + ... ) + >>> pipe = HunyuanSkyreelsImageToVideoPipeline.from_pretrained( + ... model_id, transformer=transformer, torch_dtype=torch.float16 + ... ) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." + >>> negative_prompt = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion" + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" + ... ) + + >>> output = pipe( + ... image=image, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_inference_steps=30, + ... true_cfg_scale=6.0, + ... guidance_scale=1.0, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=15) + ``` +""" + + +DEFAULT_PROMPT_TEMPLATE = { + "template": ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + ), + "crop_start": 95, +} + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class HunyuanSkyreelsImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin): + r""" + Pipeline for image-to-video generation using HunyuanVideo. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`LlamaModel`]): + [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + tokenizer (`LlamaTokenizer`): + Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + transformer ([`HunyuanVideoTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder_2 ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer_2 (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + text_encoder: LlamaModel, + tokenizer: LlamaTokenizerFast, + transformer: HunyuanVideoTransformer3DModel, + vae: AutoencoderKLHunyuanVideo, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + ) + + self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8 + self.vae_scaling_factor = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.476986 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline._get_llama_prompt_embeds + def _get_llama_prompt_embeds( + self, + prompt: str | list[str], + prompt_template: dict[str, Any], + num_videos_per_prompt: int = 1, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 256, + num_hidden_layers_to_skip: int = 2, + ) -> tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + prompt = [prompt_template["template"].format(p) for p in prompt] + + crop_start = prompt_template.get("crop_start", None) + if crop_start is None: + prompt_template_input = self.tokenizer( + prompt_template["template"], + padding="max_length", + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=False, + ) + crop_start = prompt_template_input["input_ids"].shape[-1] + # Remove <|eot_id|> token and placeholder {} + crop_start -= 2 + + max_sequence_length += crop_start + text_inputs = self.tokenizer( + prompt, + max_length=max_sequence_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + prompt_embeds = prompt_embeds.to(dtype=dtype) + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_videos_per_prompt: int = 1, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 77, + ) -> torch.Tensor: + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str] = None, + prompt_template: dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 256, + ): + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( + prompt, + prompt_template, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=max_sequence_length, + ) + + if pooled_prompt_embeds is None: + if prompt_2 is None: + prompt_2 = prompt + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=77, + ) + + return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + prompt_template=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_template is not None: + if not isinstance(prompt_template, dict): + raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}") + if "template" not in prompt_template: + raise ValueError( + f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" + ) + + def prepare_latents( + self, + image: torch.Tensor, + batch_size: int, + num_channels_latents: int = 32, + height: int = 544, + width: int = 960, + num_frames: int = 97, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + image = image.unsqueeze(2) # [B, C, 1, H, W] + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + else: + image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image] + + image_latents = torch.cat(image_latents, dim=0).to(dtype) * self.vae_scaling_factor + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height, latent_width = height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + padding_shape = (batch_size, num_channels_latents, num_latent_frames - 1, latent_height, latent_width) + + latents_padding = torch.zeros(padding_shape, dtype=dtype, device=device) + image_latents = torch.cat([image_latents, latents_padding], dim=2) + + if latents is None: + latents = randn_tensor(shape, generator=generator, dtype=dtype, device=device) + else: + latents = latents.to(dtype=dtype, device=device) + + return latents, image_latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: str | list[str] = None, + prompt_2: str | list[str] = None, + negative_prompt: str | list[str] = None, + negative_prompt_2: str | list[str] = None, + height: int = 544, + width: int = 960, + num_frames: int = 97, + num_inference_steps: int = 50, + sigmas: list[float] = None, + true_cfg_scale: float = 6.0, + guidance_scale: float = 1.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + prompt_template: dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + max_sequence_length: int = 256, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + height (`int`, defaults to `720`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `129`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + guidance_scale (`float`, defaults to `6.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. Note that the only available + HunyuanVideo model is CFG-distilled, which means that traditional guidance between unconditional and + conditional latent is not applied. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~HunyuanVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds, + callback_on_step_end_tensor_inputs, + prompt_template, + ) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + transformer_dtype = self.transformer.dtype + prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + prompt_embeds = prompt_embeds.to(transformer_dtype) + prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) + pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) + + if do_true_cfg: + negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + prompt_attention_mask=negative_prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas + ) + + # 5. Prepare latent variables + vae_dtype = self.vae.dtype + image = self.video_processor.preprocess(image, height=height, width=width).to(device, vae_dtype) + num_channels_latents = self.transformer.config.in_channels // 2 + latents, image_latents = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + latent_image_input = image_latents.to(transformer_dtype) + + # 6. Prepare guidance condition + guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(transformer_dtype) + latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=1) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + pooled_projections=pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_attention_mask, + pooled_projections=negative_pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HunyuanVideoPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py new file mode 100644 index 0000000000000000000000000000000000000000..3c6ec39398efd65f060b7e20fae2485d92eeae23 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -0,0 +1,783 @@ +# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import HunyuanVideoLoraLoaderMixin +from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HunyuanVideoPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel + >>> from diffusers.utils import export_to_video + + >>> model_id = "hunyuanvideo-community/HunyuanVideo" + >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( + ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> output = pipe( + ... prompt="A cat walks on the grass, realistic", + ... height=320, + ... width=512, + ... num_frames=61, + ... num_inference_steps=30, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=15) + ``` +""" + + +DEFAULT_PROMPT_TEMPLATE = { + "template": ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + ), + "crop_start": 95, +} + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using HunyuanVideo. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`LlamaModel`]): + [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + tokenizer (`LlamaTokenizer`): + Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + transformer ([`HunyuanVideoTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder_2 ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer_2 (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + text_encoder: LlamaModel, + tokenizer: LlamaTokenizerFast, + transformer: HunyuanVideoTransformer3DModel, + vae: AutoencoderKLHunyuanVideo, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + ) + + self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_llama_prompt_embeds( + self, + prompt: str | list[str], + prompt_template: dict[str, Any], + num_videos_per_prompt: int = 1, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 256, + num_hidden_layers_to_skip: int = 2, + ) -> tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + prompt = [prompt_template["template"].format(p) for p in prompt] + + crop_start = prompt_template.get("crop_start", None) + if crop_start is None: + prompt_template_input = self.tokenizer( + prompt_template["template"], + padding="max_length", + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=False, + ) + crop_start = prompt_template_input["input_ids"].shape[-1] + # Remove <|eot_id|> token and placeholder {} + crop_start -= 2 + + max_sequence_length += crop_start + text_inputs = self.tokenizer( + prompt, + max_length=max_sequence_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + prompt_embeds = prompt_embeds.to(dtype=dtype) + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len) + + return prompt_embeds, prompt_attention_mask + + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_videos_per_prompt: int = 1, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 77, + ) -> torch.Tensor: + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str] = None, + prompt_template: dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 256, + ): + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( + prompt, + prompt_template, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=max_sequence_length, + ) + + if pooled_prompt_embeds is None: + if prompt_2 is None: + prompt_2 = prompt + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=77, + ) + + return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + prompt_template=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_template is not None: + if not isinstance(prompt_template, dict): + raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}") + if "template" not in prompt_template: + raise ValueError( + f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" + ) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 32, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] = None, + negative_prompt: str | list[str] = None, + negative_prompt_2: str | list[str] = None, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + num_inference_steps: int = 50, + sigmas: list[float] = None, + true_cfg_scale: float = 1.0, + guidance_scale: float = 6.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + prompt_template: dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + max_sequence_length: int = 256, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + height (`int`, defaults to `720`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `129`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and + `negative_prompt` is provided. + guidance_scale (`float`, defaults to `6.0`): + Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~HunyuanVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds, + callback_on_step_end_tensor_inputs, + prompt_template, + ) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + transformer_dtype = self.transformer.dtype + prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + prompt_embeds = prompt_embeds.to(transformer_dtype) + prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) + pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) + + if do_true_cfg: + negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + prompt_attention_mask=negative_prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. Prepare guidance condition + guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(transformer_dtype) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + pooled_projections=pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_attention_mask, + pooled_projections=negative_pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HunyuanVideoPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py new file mode 100644 index 0000000000000000000000000000000000000000..f82f26eea5b9a723b79039376480d49e2517ffbb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py @@ -0,0 +1,1134 @@ +# Copyright 2025 The Framepack Team, The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from enum import Enum +from typing import Any, Callable + +import numpy as np +import torch +from transformers import ( + CLIPTextModel, + CLIPTokenizer, + LlamaModel, + LlamaTokenizerFast, + SiglipImageProcessor, + SiglipVisionModel, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import HunyuanVideoLoraLoaderMixin +from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoFramepackTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HunyuanVideoFramepackPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# TODO(yiyi): We can pack the checkpoints nicely with modular loader +EXAMPLE_DOC_STRING = """ + Examples: + ##### Image-to-Video + + ```python + >>> import torch + >>> from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel + >>> from diffusers.utils import export_to_video, load_image + >>> from transformers import SiglipImageProcessor, SiglipVisionModel + + >>> transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained( + ... "lllyasviel/FramePackI2V_HY", torch_dtype=torch.bfloat16 + ... ) + >>> feature_extractor = SiglipImageProcessor.from_pretrained( + ... "lllyasviel/flux_redux_bfl", subfolder="feature_extractor" + ... ) + >>> image_encoder = SiglipVisionModel.from_pretrained( + ... "lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16 + ... ) + >>> pipe = HunyuanVideoFramepackPipeline.from_pretrained( + ... "hunyuanvideo-community/HunyuanVideo", + ... transformer=transformer, + ... feature_extractor=feature_extractor, + ... image_encoder=image_encoder, + ... torch_dtype=torch.float16, + ... ) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png" + ... ) + >>> output = pipe( + ... image=image, + ... prompt="A penguin dancing in the snow", + ... height=832, + ... width=480, + ... num_frames=91, + ... num_inference_steps=30, + ... guidance_scale=9.0, + ... generator=torch.Generator().manual_seed(0), + ... sampling_type="inverted_anti_drifting", + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=30) + ``` + + ##### First and Last Image-to-Video + + ```python + >>> import torch + >>> from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel + >>> from diffusers.utils import export_to_video, load_image + >>> from transformers import SiglipImageProcessor, SiglipVisionModel + + >>> transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained( + ... "lllyasviel/FramePackI2V_HY", torch_dtype=torch.bfloat16 + ... ) + >>> feature_extractor = SiglipImageProcessor.from_pretrained( + ... "lllyasviel/flux_redux_bfl", subfolder="feature_extractor" + ... ) + >>> image_encoder = SiglipVisionModel.from_pretrained( + ... "lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16 + ... ) + >>> pipe = HunyuanVideoFramepackPipeline.from_pretrained( + ... "hunyuanvideo-community/HunyuanVideo", + ... transformer=transformer, + ... feature_extractor=feature_extractor, + ... image_encoder=image_encoder, + ... torch_dtype=torch.float16, + ... ) + >>> pipe.to("cuda") + + >>> prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective." + >>> first_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png" + ... ) + >>> last_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png" + ... ) + >>> output = pipe( + ... image=first_image, + ... last_image=last_image, + ... prompt=prompt, + ... height=512, + ... width=512, + ... num_frames=91, + ... num_inference_steps=30, + ... guidance_scale=9.0, + ... generator=torch.Generator().manual_seed(0), + ... sampling_type="inverted_anti_drifting", + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=30) + ``` +""" + + +DEFAULT_PROMPT_TEMPLATE = { + "template": ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + ), + "crop_start": 95, +} + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FramepackSamplingType(str, Enum): + VANILLA = "vanilla" + INVERTED_ANTI_DRIFTING = "inverted_anti_drifting" + + +class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using HunyuanVideo. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`LlamaModel`]): + [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + tokenizer (`LlamaTokenizer`): + Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + transformer ([`HunyuanVideoTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder_2 ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer_2 (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + text_encoder: LlamaModel, + tokenizer: LlamaTokenizerFast, + transformer: HunyuanVideoFramepackTransformer3DModel, + vae: AutoencoderKLHunyuanVideo, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + image_encoder: SiglipVisionModel, + feature_extractor: SiglipImageProcessor, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + + self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline._get_llama_prompt_embeds + def _get_llama_prompt_embeds( + self, + prompt: str | list[str], + prompt_template: dict[str, Any], + num_videos_per_prompt: int = 1, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 256, + num_hidden_layers_to_skip: int = 2, + ) -> tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + prompt = [prompt_template["template"].format(p) for p in prompt] + + crop_start = prompt_template.get("crop_start", None) + if crop_start is None: + prompt_template_input = self.tokenizer( + prompt_template["template"], + padding="max_length", + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=False, + ) + crop_start = prompt_template_input["input_ids"].shape[-1] + # Remove <|eot_id|> token and placeholder {} + crop_start -= 2 + + max_sequence_length += crop_start + text_inputs = self.tokenizer( + prompt, + max_length=max_sequence_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + prompt_embeds = prompt_embeds.to(dtype=dtype) + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_videos_per_prompt: int = 1, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 77, + ) -> torch.Tensor: + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str] = None, + prompt_template: dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 256, + ): + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( + prompt, + prompt_template, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=max_sequence_length, + ) + + if pooled_prompt_embeds is None: + if prompt_2 is None: + prompt_2 = prompt + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=77, + ) + + return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask + + def encode_image(self, image: torch.Tensor, device: torch.device | None = None, dtype: torch.dtype | None = None): + device = device or self._execution_device + image = (image + 1) / 2.0 # [-1, 1] -> [0, 1] + image = self.feature_extractor(images=image, return_tensors="pt", do_rescale=False).to( + device=device, dtype=self.image_encoder.dtype + ) + image_embeds = self.image_encoder(**image).last_hidden_state + return image_embeds.to(dtype=dtype) + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + prompt_template=None, + image=None, + image_latents=None, + last_image=None, + last_image_latents=None, + sampling_type=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_template is not None: + if not isinstance(prompt_template, dict): + raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}") + if "template" not in prompt_template: + raise ValueError( + f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" + ) + + sampling_types = [x.value for x in FramepackSamplingType.__members__.values()] + if sampling_type not in sampling_types: + raise ValueError(f"`sampling_type` has to be one of '{sampling_types}' but is '{sampling_type}'") + + if image is not None and image_latents is not None: + raise ValueError("Only one of `image` or `image_latents` can be passed.") + if last_image is not None and last_image_latents is not None: + raise ValueError("Only one of `last_image` or `last_image_latents` can be passed.") + if sampling_type != FramepackSamplingType.INVERTED_ANTI_DRIFTING and ( + last_image is not None or last_image_latents is not None + ): + raise ValueError( + 'Only `"inverted_anti_drifting"` inference type supports `last_image` or `last_image_latents`.' + ) + + def prepare_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 16, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def prepare_image_latents( + self, + image: torch.Tensor, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + device = device or self._execution_device + if latents is None: + image = image.unsqueeze(2).to(device=device, dtype=self.vae.dtype) + latents = self.vae.encode(image).latent_dist.sample(generator=generator) + latents = latents * self.vae.config.scaling_factor + return latents.to(device=device, dtype=dtype) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + last_image: PipelineImageInput | None = None, + prompt: str | list[str] = None, + prompt_2: str | list[str] = None, + negative_prompt: str | list[str] = None, + negative_prompt_2: str | list[str] = None, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + latent_window_size: int = 9, + num_inference_steps: int = 50, + sigmas: list[float] = None, + true_cfg_scale: float = 1.0, + guidance_scale: float = 6.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + image_latents: torch.Tensor | None = None, + last_image_latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + prompt_template: dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + max_sequence_length: int = 256, + sampling_type: FramepackSamplingType = FramepackSamplingType.INVERTED_ANTI_DRIFTING, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`): + The image to be used as the starting point for the video generation. + last_image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`, *optional*): + The optional last image to be used as the ending point for the video generation. This is useful for + generating transitions between two images. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + height (`int`, defaults to `720`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `129`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + guidance_scale (`float`, defaults to `6.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. Note that the only available + HunyuanVideo model is CFG-distilled, which means that traditional guidance between unconditional and + conditional latent is not applied. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + image_latents (`torch.Tensor`, *optional*): + Pre-encoded image latents. If not provided, the image will be encoded using the VAE. + last_image_latents (`torch.Tensor`, *optional*): + Pre-encoded last image latents. If not provided, the last image will be encoded using the VAE. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideoFramepackPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~HunyuanVideoFramepackPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideoFramepackPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images and the second element is a list + of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) + content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds, + callback_on_step_end_tensor_inputs, + prompt_template, + image, + image_latents, + last_image, + last_image_latents, + sampling_type, + ) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + transformer_dtype = self.transformer.dtype + vae_dtype = self.vae.dtype + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + transformer_dtype = self.transformer.dtype + prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + prompt_embeds = prompt_embeds.to(transformer_dtype) + prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) + pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) + + if do_true_cfg: + negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + prompt_attention_mask=negative_prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype) + + # 4. Prepare image + image = self.video_processor.preprocess(image, height, width) + image_embeds = self.encode_image(image, device=device).to(transformer_dtype) + if last_image is not None: + # Credits: https://github.com/lllyasviel/FramePack/pull/167 + # Users can modify the weighting strategy applied here + last_image = self.video_processor.preprocess(last_image, height, width) + last_image_embeds = self.encode_image(last_image, device=device).to(transformer_dtype) + last_image_embeds = (image_embeds + last_image_embeds) / 2 + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + window_num_frames = (latent_window_size - 1) * self.vae_scale_factor_temporal + 1 + num_latent_sections = max(1, (num_frames + window_num_frames - 1) // window_num_frames) + history_video = None + total_generated_latent_frames = 0 + + image_latents = self.prepare_image_latents( + image, dtype=torch.float32, device=device, generator=generator, latents=image_latents + ) + if last_image is not None: + last_image_latents = self.prepare_image_latents( + last_image, dtype=torch.float32, device=device, generator=generator + ) + + # Specific to the released checkpoints: + # - https://huggingface.co/lllyasviel/FramePackI2V_HY + # - https://huggingface.co/lllyasviel/FramePack_F1_I2V_HY_20250503 + # TODO: find a more generic way in future if there are more checkpoints + if sampling_type == FramepackSamplingType.INVERTED_ANTI_DRIFTING: + history_sizes = [1, 2, 16] + history_latents = torch.zeros( + batch_size, + num_channels_latents, + sum(history_sizes), + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + device=device, + dtype=torch.float32, + ) + + elif sampling_type == FramepackSamplingType.VANILLA: + history_sizes = [16, 2, 1] + history_latents = torch.zeros( + batch_size, + num_channels_latents, + sum(history_sizes), + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + device=device, + dtype=torch.float32, + ) + history_latents = torch.cat([history_latents, image_latents], dim=2) + total_generated_latent_frames += 1 + + else: + assert False + + # 6. Prepare guidance condition + guidance = torch.tensor([guidance_scale] * batch_size, dtype=transformer_dtype, device=device) * 1000.0 + + # 7. Denoising loop + for k in range(num_latent_sections): + if sampling_type == FramepackSamplingType.INVERTED_ANTI_DRIFTING: + latent_paddings = list(reversed(range(num_latent_sections))) + if num_latent_sections > 4: + latent_paddings = [3] + [2] * (num_latent_sections - 3) + [1, 0] + + is_first_section = k == 0 + is_last_section = k == num_latent_sections - 1 + latent_padding_size = latent_paddings[k] * latent_window_size + + indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, *history_sizes])) + ( + indices_prefix, + indices_padding, + indices_latents, + indices_latents_history_1x, + indices_latents_history_2x, + indices_latents_history_4x, + ) = indices.split([1, latent_padding_size, latent_window_size, *history_sizes], dim=0) + # Inverted anti-drifting sampling: Figure 2(c) in the paper + indices_clean_latents = torch.cat([indices_prefix, indices_latents_history_1x], dim=0) + + latents_prefix = image_latents + latents_history_1x, latents_history_2x, latents_history_4x = history_latents[ + :, :, : sum(history_sizes) + ].split(history_sizes, dim=2) + if last_image is not None and is_first_section: + latents_history_1x = last_image_latents + latents_clean = torch.cat([latents_prefix, latents_history_1x], dim=2) + + elif sampling_type == FramepackSamplingType.VANILLA: + indices = torch.arange(0, sum([1, *history_sizes, latent_window_size])) + ( + indices_prefix, + indices_latents_history_4x, + indices_latents_history_2x, + indices_latents_history_1x, + indices_latents, + ) = indices.split([1, *history_sizes, latent_window_size], dim=0) + indices_clean_latents = torch.cat([indices_prefix, indices_latents_history_1x], dim=0) + + latents_prefix = image_latents + latents_history_4x, latents_history_2x, latents_history_1x = history_latents[ + :, :, -sum(history_sizes) : + ].split(history_sizes, dim=2) + latents_clean = torch.cat([latents_prefix, latents_history_1x], dim=2) + + else: + assert False + + latents = self.prepare_latents( + batch_size, + num_channels_latents, + height, + width, + window_num_frames, + dtype=torch.float32, + device=device, + generator=generator, + latents=None, + ) + + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + image_seq_len = ( + latents.shape[2] * latents.shape[3] * latents.shape[4] / self.transformer.config.patch_size**2 + ) + exp_max = 7.0 + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + mu = min(mu, math.log(exp_max)) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu + ) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]) + + noise_pred = self.transformer( + hidden_states=latents.to(transformer_dtype), + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + pooled_projections=pooled_prompt_embeds, + image_embeds=image_embeds, + indices_latents=indices_latents, + guidance=guidance, + latents_clean=latents_clean.to(transformer_dtype), + indices_latents_clean=indices_clean_latents, + latents_history_2x=latents_history_2x.to(transformer_dtype), + indices_latents_history_2x=indices_latents_history_2x, + latents_history_4x=latents_history_4x.to(transformer_dtype), + indices_latents_history_4x=indices_latents_history_4x, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + neg_noise_pred = self.transformer( + hidden_states=latents.to(transformer_dtype), + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_attention_mask, + pooled_projections=negative_pooled_prompt_embeds, + image_embeds=image_embeds, + indices_latents=indices_latents, + guidance=guidance, + latents_clean=latents_clean.to(transformer_dtype), + indices_latents_clean=indices_clean_latents, + latents_history_2x=latents_history_2x.to(transformer_dtype), + indices_latents_history_2x=indices_latents_history_2x, + latents_history_4x=latents_history_4x.to(transformer_dtype), + indices_latents_history_4x=indices_latents_history_4x, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.float(), t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if sampling_type == FramepackSamplingType.INVERTED_ANTI_DRIFTING: + if is_last_section: + latents = torch.cat([image_latents, latents], dim=2) + total_generated_latent_frames += latents.shape[2] + history_latents = torch.cat([latents, history_latents], dim=2) + real_history_latents = history_latents[:, :, :total_generated_latent_frames] + section_latent_frames = ( + (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2) + ) + index_slice = (slice(None), slice(None), slice(0, section_latent_frames)) + + elif sampling_type == FramepackSamplingType.VANILLA: + total_generated_latent_frames += latents.shape[2] + history_latents = torch.cat([history_latents, latents], dim=2) + real_history_latents = history_latents[:, :, -total_generated_latent_frames:] + section_latent_frames = latent_window_size * 2 + index_slice = (slice(None), slice(None), slice(-section_latent_frames, None)) + + else: + assert False + + if history_video is None: + if not output_type == "latent": + current_latents = real_history_latents.to(vae_dtype) / self.vae.config.scaling_factor + history_video = self.vae.decode(current_latents, return_dict=False)[0] + else: + history_video = [real_history_latents] + else: + if not output_type == "latent": + overlapped_frames = (latent_window_size - 1) * self.vae_scale_factor_temporal + 1 + current_latents = ( + real_history_latents[index_slice].to(vae_dtype) / self.vae.config.scaling_factor + ) + current_video = self.vae.decode(current_latents, return_dict=False)[0] + + if sampling_type == FramepackSamplingType.INVERTED_ANTI_DRIFTING: + history_video = self._soft_append(current_video, history_video, overlapped_frames) + elif sampling_type == FramepackSamplingType.VANILLA: + history_video = self._soft_append(history_video, current_video, overlapped_frames) + else: + assert False + else: + history_video.append(real_history_latents) + + self._current_timestep = None + + if not output_type == "latent": + generated_frames = history_video.size(2) + generated_frames = ( + generated_frames - 1 + ) // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + history_video = history_video[:, :, :generated_frames] + video = self.video_processor.postprocess_video(history_video, output_type=output_type) + else: + video = history_video + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HunyuanVideoFramepackPipelineOutput(frames=video) + + def _soft_append(self, history: torch.Tensor, current: torch.Tensor, overlap: int = 0): + if overlap <= 0: + return torch.cat([history, current], dim=2) + + assert history.shape[2] >= overlap, f"Current length ({history.shape[2]}) must be >= overlap ({overlap})" + assert current.shape[2] >= overlap, f"History length ({current.shape[2]}) must be >= overlap ({overlap})" + + weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1) + blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap] + output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2) + + return output.to(history) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py new file mode 100644 index 0000000000000000000000000000000000000000..c599488c2379ee2ef436f74fae72efe52680f869 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py @@ -0,0 +1,1002 @@ +# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + LlamaTokenizerFast, + LlavaForConditionalGeneration, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import HunyuanVideoLoraLoaderMixin +from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HunyuanVideoPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import HunyuanVideoImageToVideoPipeline, HunyuanVideoTransformer3DModel + >>> from diffusers.utils import load_image, export_to_video + + >>> # Available checkpoints: hunyuanvideo-community/HunyuanVideo-I2V, hunyuanvideo-community/HunyuanVideo-I2V-33ch + >>> model_id = "hunyuanvideo-community/HunyuanVideo-I2V" + >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( + ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = HunyuanVideoImageToVideoPipeline.from_pretrained( + ... model_id, transformer=transformer, torch_dtype=torch.float16 + ... ) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> prompt = "A man with short gray hair plays a red electric guitar." + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/guitar-man.png" + ... ) + + >>> # If using hunyuanvideo-community/HunyuanVideo-I2V + >>> output = pipe(image=image, prompt=prompt, guidance_scale=6.0).frames[0] + + >>> # If using hunyuanvideo-community/HunyuanVideo-I2V-33ch + >>> output = pipe(image=image, prompt=prompt, guidance_scale=1.0, true_cfg_scale=1.0).frames[0] + + >>> export_to_video(output, "output.mp4", fps=15) + ``` +""" + + +DEFAULT_PROMPT_TEMPLATE = { + "template": ( + "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ), + "crop_start": 103, + "image_emb_start": 5, + "image_emb_end": 581, + "image_emb_len": 576, + "double_return_token_id": 271, +} + + +def _expand_input_ids_with_image_tokens( + text_input_ids, + prompt_attention_mask, + max_sequence_length, + image_token_index, + image_emb_len, + image_emb_start, + image_emb_end, + pad_token_id, +): + special_image_token_mask = text_input_ids == image_token_index + num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) + batch_indices, non_image_indices = torch.where(text_input_ids != image_token_index) + + max_expanded_length = max_sequence_length + (num_special_image_tokens.max() * (image_emb_len - 1)) + new_token_positions = torch.cumsum((special_image_token_mask * (image_emb_len - 1) + 1), -1) - 1 + text_to_overwrite = new_token_positions[batch_indices, non_image_indices] + + expanded_input_ids = torch.full( + (text_input_ids.shape[0], max_expanded_length), + pad_token_id, + dtype=text_input_ids.dtype, + device=text_input_ids.device, + ) + expanded_input_ids[batch_indices, text_to_overwrite] = text_input_ids[batch_indices, non_image_indices] + expanded_input_ids[batch_indices, image_emb_start:image_emb_end] = image_token_index + + expanded_attention_mask = torch.zeros( + (text_input_ids.shape[0], max_expanded_length), + dtype=prompt_attention_mask.dtype, + device=prompt_attention_mask.device, + ) + attn_batch_indices, attention_indices = torch.where(expanded_input_ids != pad_token_id) + expanded_attention_mask[attn_batch_indices, attention_indices] = 1.0 + expanded_attention_mask = expanded_attention_mask.to(prompt_attention_mask.dtype) + position_ids = (expanded_attention_mask.cumsum(-1) - 1).masked_fill_((expanded_attention_mask == 0), 1) + + return { + "input_ids": expanded_input_ids, + "attention_mask": expanded_attention_mask, + "position_ids": position_ids, + } + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin): + r""" + Pipeline for image-to-video generation using HunyuanVideo. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`LlavaForConditionalGeneration`]): + [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + tokenizer (`LlamaTokenizer`): + Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + transformer ([`HunyuanVideoTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder_2 ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer_2 (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + text_encoder: LlavaForConditionalGeneration, + tokenizer: LlamaTokenizerFast, + transformer: HunyuanVideoTransformer3DModel, + vae: AutoencoderKLHunyuanVideo, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + image_processor=image_processor, + ) + + self.vae_scaling_factor = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.476986 + self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_llama_prompt_embeds( + self, + image: torch.Tensor, + prompt: str | list[str], + prompt_template: dict[str, Any], + num_videos_per_prompt: int = 1, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 256, + num_hidden_layers_to_skip: int = 2, + image_embed_interleave: int = 2, + ) -> tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_template["template"].format(p) for p in prompt] + + crop_start = prompt_template.get("crop_start", None) + + image_emb_len = prompt_template.get("image_emb_len", 576) + image_emb_start = prompt_template.get("image_emb_start", 5) + image_emb_end = prompt_template.get("image_emb_end", 581) + double_return_token_id = prompt_template.get("double_return_token_id", 271) + + if crop_start is None: + prompt_template_input = self.tokenizer( + prompt_template["template"], + padding="max_length", + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=False, + ) + crop_start = prompt_template_input["input_ids"].shape[-1] + # Remove <|start_header_id|>, <|end_header_id|>, assistant, <|eot_id|>, and placeholder {} + crop_start -= 5 + + max_sequence_length += crop_start + text_inputs = self.tokenizer( + prompt, + max_length=max_sequence_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device) + + image_token_index = self.text_encoder.config.image_token_index + pad_token_id = self.text_encoder.config.pad_token_id + expanded_inputs = _expand_input_ids_with_image_tokens( + text_input_ids, + prompt_attention_mask, + max_sequence_length, + image_token_index, + image_emb_len, + image_emb_start, + image_emb_end, + pad_token_id, + ) + prompt_embeds = self.text_encoder( + **expanded_inputs, + pixel_values=image_embeds, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + prompt_embeds = prompt_embeds.to(dtype=dtype) + + if crop_start is not None and crop_start > 0: + text_crop_start = crop_start - 1 + image_emb_len + batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id) + + if last_double_return_token_indices.shape[0] == 3: + # in case the prompt is too long + last_double_return_token_indices = torch.cat( + (last_double_return_token_indices, torch.tensor([text_input_ids.shape[-1]])) + ) + batch_indices = torch.cat((batch_indices, torch.tensor([0]))) + + last_double_return_token_indices = last_double_return_token_indices.reshape(text_input_ids.shape[0], -1)[ + :, -1 + ] + batch_indices = batch_indices.reshape(text_input_ids.shape[0], -1)[:, -1] + assistant_crop_start = last_double_return_token_indices - 1 + image_emb_len - 4 + assistant_crop_end = last_double_return_token_indices - 1 + image_emb_len + attention_mask_assistant_crop_start = last_double_return_token_indices - 4 + attention_mask_assistant_crop_end = last_double_return_token_indices + + prompt_embed_list = [] + prompt_attention_mask_list = [] + image_embed_list = [] + image_attention_mask_list = [] + + for i in range(text_input_ids.shape[0]): + prompt_embed_list.append( + torch.cat( + [ + prompt_embeds[i, text_crop_start : assistant_crop_start[i].item()], + prompt_embeds[i, assistant_crop_end[i].item() :], + ] + ) + ) + prompt_attention_mask_list.append( + torch.cat( + [ + prompt_attention_mask[i, crop_start : attention_mask_assistant_crop_start[i].item()], + prompt_attention_mask[i, attention_mask_assistant_crop_end[i].item() :], + ] + ) + ) + image_embed_list.append(prompt_embeds[i, image_emb_start:image_emb_end]) + image_attention_mask_list.append( + torch.ones(image_embed_list[-1].shape[0]).to(prompt_embeds.device).to(prompt_attention_mask.dtype) + ) + + prompt_embed_list = torch.stack(prompt_embed_list) + prompt_attention_mask_list = torch.stack(prompt_attention_mask_list) + image_embed_list = torch.stack(image_embed_list) + image_attention_mask_list = torch.stack(image_attention_mask_list) + + if 0 < image_embed_interleave < 6: + image_embed_list = image_embed_list[:, ::image_embed_interleave, :] + image_attention_mask_list = image_attention_mask_list[:, ::image_embed_interleave] + + assert ( + prompt_embed_list.shape[0] == prompt_attention_mask_list.shape[0] + and image_embed_list.shape[0] == image_attention_mask_list.shape[0] + ) + + prompt_embeds = torch.cat([image_embed_list, prompt_embed_list], dim=1) + prompt_attention_mask = torch.cat([image_attention_mask_list, prompt_attention_mask_list], dim=1) + + return prompt_embeds, prompt_attention_mask + + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_videos_per_prompt: int = 1, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 77, + ) -> torch.Tensor: + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output + return prompt_embeds + + def encode_prompt( + self, + image: torch.Tensor, + prompt: str | list[str], + prompt_2: str | list[str] = None, + prompt_template: dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + max_sequence_length: int = 256, + image_embed_interleave: int = 2, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( + image, + prompt, + prompt_template, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=max_sequence_length, + image_embed_interleave=image_embed_interleave, + ) + + if pooled_prompt_embeds is None: + if prompt_2 is None: + prompt_2 = prompt + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=77, + ) + + return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + prompt_template=None, + true_cfg_scale=1.0, + guidance_scale=1.0, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_template is not None: + if not isinstance(prompt_template, dict): + raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}") + if "template" not in prompt_template: + raise ValueError( + f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" + ) + + if true_cfg_scale > 1.0 and guidance_scale > 1.0: + logger.warning( + "Both `true_cfg_scale` and `guidance_scale` are greater than 1.0. This will result in both " + "classifier-free guidance and embedded-guidance to be applied. This is not recommended " + "as it may lead to higher memory usage, slower inference and potentially worse results." + ) + + def prepare_latents( + self, + image: torch.Tensor, + batch_size: int, + num_channels_latents: int = 32, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + image_condition_type: str = "latent_concat", + ) -> torch.Tensor: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height, latent_width = height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + + image = image.unsqueeze(2) # [B, C, 1, H, W] + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i], "argmax") + for i in range(batch_size) + ] + else: + image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator, "argmax") for img in image] + + image_latents = torch.cat(image_latents, dim=0).to(dtype) * self.vae_scaling_factor + image_latents = image_latents.repeat(1, 1, num_latent_frames, 1, 1) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + t = torch.tensor([0.999]).to(device=device) + latents = latents * t + image_latents * (1 - t) + + if image_condition_type == "token_replace": + image_latents = image_latents[:, :, :1] + + return latents, image_latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PIL.Image.Image, + prompt: str | list[str] = None, + prompt_2: str | list[str] = None, + negative_prompt: str | list[str] = None, + negative_prompt_2: str | list[str] = None, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + num_inference_steps: int = 50, + sigmas: list[float] = None, + true_cfg_scale: float = 1.0, + guidance_scale: float = 1.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + prompt_template: dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + max_sequence_length: int = 256, + image_embed_interleave: int | None = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + height (`int`, defaults to `720`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `129`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + guidance_scale (`float`, defaults to `1.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. Note that the only available + HunyuanVideo model is CFG-distilled, which means that traditional guidance between unconditional and + conditional latent is not applied. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~HunyuanVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds, + callback_on_step_end_tensor_inputs, + prompt_template, + true_cfg_scale, + guidance_scale, + ) + + image_condition_type = self.transformer.config.image_condition_type + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + image_embed_interleave = ( + image_embed_interleave + if image_embed_interleave is not None + else ( + 2 if image_condition_type == "latent_concat" else 4 if image_condition_type == "token_replace" else 1 + ) + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Prepare latent variables + vae_dtype = self.vae.dtype + image_tensor = self.video_processor.preprocess(image, height, width).to(device, vae_dtype) + + if image_condition_type == "latent_concat": + num_channels_latents = (self.transformer.config.in_channels - 1) // 2 + elif image_condition_type == "token_replace": + num_channels_latents = self.transformer.config.in_channels + + latents, image_latents = self.prepare_latents( + image_tensor, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + image_condition_type, + ) + if image_condition_type == "latent_concat": + image_latents[:, :, 1:] = 0 + mask = image_latents.new_ones(image_latents.shape[0], 1, *image_latents.shape[2:]) + mask[:, :, 1:] = 0 + + # 4. Encode input prompt + transformer_dtype = self.transformer.dtype + prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( + image=image, + prompt=prompt, + prompt_2=prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + image_embed_interleave=image_embed_interleave, + ) + prompt_embeds = prompt_embeds.to(transformer_dtype) + prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) + pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) + + if do_true_cfg: + black_image = PIL.Image.new("RGB", (width, height), 0) + negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt( + image=black_image, + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + prompt_attention_mask=negative_prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + + # 6. Prepare guidance condition + guidance = None + if self.transformer.config.guidance_embeds: + guidance = ( + torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 + ) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + if image_condition_type == "latent_concat": + latent_model_input = torch.cat([latents, image_latents, mask], dim=1).to(transformer_dtype) + elif image_condition_type == "token_replace": + latent_model_input = torch.cat([image_latents, latents[:, :, 1:]], dim=2).to(transformer_dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + pooled_projections=pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_attention_mask, + pooled_projections=negative_pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + if image_condition_type == "latent_concat": + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + elif image_condition_type == "token_replace": + latents = latents = self.scheduler.step( + noise_pred[:, :, 1:], t, latents[:, :, 1:], return_dict=False + )[0] + latents = torch.cat([image_latents, latents], dim=2) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae_scaling_factor + video = self.vae.decode(latents, return_dict=False)[0] + if image_condition_type == "latent_concat": + video = video[:, :, 4:, :, :] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + if image_condition_type == "latent_concat": + video = latents[:, :, 1:, :, :] + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HunyuanVideoPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..4cf96626277047d35d7d06e1481af6ee2fb56447 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video/pipeline_output.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class HunyuanVideoPipelineOutput(BaseOutput): + r""" + Output class for HunyuanVideo pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor + + +@dataclass +class HunyuanVideoFramepackPipelineOutput(BaseOutput): + r""" + Output class for HunyuanVideo pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. Or, a list of torch tensors where each tensor + corresponds to a latent that decodes to multiple frames. + """ + + frames: torch.Tensor | np.ndarray | list[list[PIL.Image.Image]] | list[torch.Tensor] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video1_5/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video1_5/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..846320f4ace07d502c0c8d067d63527a80664882 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video1_5/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_hunyuan_video1_5"] = ["HunyuanVideo15Pipeline"] + _import_structure["pipeline_hunyuan_video1_5_image2video"] = ["HunyuanVideo15ImageToVideoPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_hunyuan_video1_5 import HunyuanVideo15Pipeline + from .pipeline_hunyuan_video1_5_image2video import HunyuanVideo15ImageToVideoPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video1_5/image_processor.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video1_5/image_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..82817365b6a529c2307540af8e3fbb268ad0b05a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video1_5/image_processor.py @@ -0,0 +1,103 @@ +# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +from ...configuration_utils import register_to_config +from ...video_processor import VideoProcessor + + +# copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/hyvideo/utils/data_utils.py#L20 +def generate_crop_size_list(base_size=256, patch_size=16, max_ratio=4.0): + num_patches = round((base_size / patch_size) ** 2) + assert max_ratio >= 1.0 + crop_size_list = [] + wp, hp = num_patches, 1 + while wp > 0: + if max(wp, hp) / min(wp, hp) <= max_ratio: + crop_size_list.append((wp * patch_size, hp * patch_size)) + if (hp + 1) * wp <= num_patches: + hp += 1 + else: + wp -= 1 + return crop_size_list + + +# copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/hyvideo/utils/data_utils.py#L38 +def get_closest_ratio(height: float, width: float, ratios: list, buckets: list): + """ + Get the closest ratio in the buckets. + + Args: + height (float): video height + width (float): video width + ratios (list): video aspect ratio + buckets (list): buckets generated by `generate_crop_size_list` + + Returns: + the closest size in the buckets and the corresponding ratio + """ + aspect_ratio = float(height) / float(width) + diff_ratios = ratios - aspect_ratio + + if aspect_ratio >= 1: + indices = [(index, x) for index, x in enumerate(diff_ratios) if x <= 0] + else: + indices = [(index, x) for index, x in enumerate(diff_ratios) if x >= 0] + + closest_ratio_id = min(indices, key=lambda pair: abs(pair[1]))[0] + closest_size = buckets[closest_ratio_id] + closest_ratio = ratios[closest_ratio_id] + + return closest_size, closest_ratio + + +class HunyuanVideo15ImageProcessor(VideoProcessor): + r""" + Image/video processor to preproces/postprocess the reference image/generatedvideo for the HunyuanVideo1.5 model. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept + `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. + vae_scale_factor (`int`, *optional*, defaults to `16`): + VAE (spatial) scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of + this factor. + vae_latent_channels (`int`, *optional*, defaults to `32`): + VAE latent channels. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + """ + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 16, + vae_latent_channels: int = 32, + do_convert_rgb: bool = True, + ): + super().__init__( + do_resize=do_resize, + vae_scale_factor=vae_scale_factor, + vae_latent_channels=vae_latent_channels, + do_convert_rgb=do_convert_rgb, + ) + + def calculate_default_height_width(self, height: int, width: int, target_size: int): + crop_size_list = generate_crop_size_list(base_size=target_size, patch_size=self.config.vae_scale_factor) + aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list]) + height, width = get_closest_ratio(height, width, aspect_ratios, crop_size_list)[0] + + return height, width diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py new file mode 100644 index 0000000000000000000000000000000000000000..a0adff493ac0e7243f64a8763be74f4a5daeb9a8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py @@ -0,0 +1,837 @@ +# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import re +from typing import Any + +import numpy as np +import torch +from transformers import ByT5Tokenizer, Qwen2_5_VLTextModel, Qwen2Tokenizer, T5EncoderModel + +from ...guiders import ClassifierFreeGuidance +from ...models import AutoencoderKLHunyuanVideo15, HunyuanVideo15Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .image_processor import HunyuanVideo15ImageProcessor +from .pipeline_output import HunyuanVideo15PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import HunyuanVideo15Pipeline + >>> from diffusers.utils import export_to_video + + >>> model_id = "hunyuanvideo-community/HunyuanVideo-1.5-480p_t2v" + >>> pipe = HunyuanVideo15Pipeline.from_pretrained(model_id, torch_dtype=torch.float16) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> output = pipe( + ... prompt="A cat walks on the grass, realistic", + ... num_inference_steps=50, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=15) + ``` +""" + + +def format_text_input(prompt: list[str], system_message: str) -> list[dict[str, Any]]: + """ + Apply text to template. + + Args: + prompt (list[str]): Input text. + system_message (str): System message. + + Returns: + list[dict[str, Any]]: List of chat conversation. + """ + + template = [ + [{"role": "system", "content": system_message}, {"role": "user", "content": p if p else " "}] for p in prompt + ] + + return template + + +def extract_glyph_texts(prompt: str) -> list[str]: + """ + Extract glyph texts from prompt using regex pattern. + + Args: + prompt: Input prompt string + + Returns: + List of extracted glyph texts + """ + pattern = r"\"(.*?)\"|“(.*?)”" + matches = re.findall(pattern, prompt) + result = [match[0] or match[1] for match in matches] + result = list(dict.fromkeys(result)) if len(result) > 1 else result + + if result: + formatted_result = ". ".join([f'Text "{text}"' for text in result]) + ". " + else: + formatted_result = None + + return formatted_result + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class HunyuanVideo15Pipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using HunyuanVideo1.5. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + transformer ([`HunyuanVideo15Transformer3DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + vae ([`AutoencoderKLHunyuanVideo15`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer]. + text_encoder_2 ([`T5EncoderModel`]): + [T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel) + variant. + tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer] + guider ([`ClassifierFreeGuidance`]): + [ClassifierFreeGuidance]for classifier free guidance. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + text_encoder: Qwen2_5_VLTextModel, + tokenizer: Qwen2Tokenizer, + transformer: HunyuanVideo15Transformer3DModel, + vae: AutoencoderKLHunyuanVideo15, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: T5EncoderModel, + tokenizer_2: ByT5Tokenizer, + guider: ClassifierFreeGuidance, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + guider=guider, + ) + + self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 16 + self.video_processor = HunyuanVideo15ImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.target_size = self.transformer.config.target_size if getattr(self, "transformer", None) else 640 + self.vision_states_dim = ( + self.transformer.config.image_embed_dim if getattr(self, "transformer", None) else 1152 + ) + self.num_channels_latents = self.vae.config.latent_channels if hasattr(self, "vae") else 32 + # fmt: off + self.system_message = "You are a helpful assistant. Describe the video by detailing the following aspects: \ + 1. The main content and theme of the video. \ + 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ + 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ + 4. background environment, light, style and atmosphere. \ + 5. camera angles, movements, and transitions used in the video." + # fmt: on + self.prompt_template_encode_start_idx = 108 + self.tokenizer_max_length = 1000 + self.tokenizer_2_max_length = 256 + self.vision_num_semantic_tokens = 729 + self.default_aspect_ratio = (16, 9) # (width: height) + + @staticmethod + def _get_mllm_prompt_embeds( + text_encoder: Qwen2_5_VLTextModel, + tokenizer: Qwen2Tokenizer, + prompt: str | list[str], + device: torch.device, + tokenizer_max_length: int = 1000, + num_hidden_layers_to_skip: int = 2, + # fmt: off + system_message: str = "You are a helpful assistant. Describe the video by detailing the following aspects: \ + 1. The main content and theme of the video. \ + 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ + 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ + 4. background environment, light, style and atmosphere. \ + 5. camera angles, movements, and transitions used in the video.", + # fmt: on + crop_start: int = 108, + ) -> tuple[torch.Tensor, torch.Tensor]: + prompt = [prompt] if isinstance(prompt, str) else prompt + + prompt = format_text_input(prompt, system_message) + + text_inputs = tokenizer.apply_chat_template( + prompt, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + padding="max_length", + max_length=tokenizer_max_length + crop_start, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + return prompt_embeds, prompt_attention_mask + + @staticmethod + def _get_byt5_prompt_embeds( + tokenizer: ByT5Tokenizer, + text_encoder: T5EncoderModel, + prompt: str | list[str], + device: torch.device, + tokenizer_max_length: int = 256, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + + glyph_texts = [extract_glyph_texts(p) for p in prompt] + + prompt_embeds_list = [] + prompt_embeds_mask_list = [] + + for glyph_text in glyph_texts: + if glyph_text is None: + glyph_text_embeds = torch.zeros( + (1, tokenizer_max_length, text_encoder.config.d_model), device=device, dtype=text_encoder.dtype + ) + glyph_text_embeds_mask = torch.zeros((1, tokenizer_max_length), device=device, dtype=torch.int64) + else: + txt_tokens = tokenizer( + glyph_text, + padding="max_length", + max_length=tokenizer_max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ).to(device) + + glyph_text_embeds = text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask.float(), + )[0] + glyph_text_embeds = glyph_text_embeds.to(device=device) + glyph_text_embeds_mask = txt_tokens.attention_mask.to(device=device) + + prompt_embeds_list.append(glyph_text_embeds) + prompt_embeds_mask_list.append(glyph_text_embeds_mask) + + prompt_embeds = torch.cat(prompt_embeds_list, dim=0) + prompt_embeds_mask = torch.cat(prompt_embeds_mask_list, dim=0) + + return prompt_embeds, prompt_embeds_mask + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + dtype: torch.dtype | None = None, + batch_size: int = 1, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + prompt_embeds_2: torch.Tensor | None = None, + prompt_embeds_mask_2: torch.Tensor | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + batch_size (`int`): + batch size of prompts, defaults to 1 + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input + argument. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input + argument using self.tokenizer_2 and self.text_encoder_2. + prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input + argument using self.tokenizer_2 and self.text_encoder_2. + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + if prompt is None: + prompt = [""] * batch_size + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_mllm_prompt_embeds( + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + prompt=prompt, + device=device, + tokenizer_max_length=self.tokenizer_max_length, + system_message=self.system_message, + crop_start=self.prompt_template_encode_start_idx, + ) + + if prompt_embeds_2 is None: + prompt_embeds_2, prompt_embeds_mask_2 = self._get_byt5_prompt_embeds( + tokenizer=self.tokenizer_2, + text_encoder=self.text_encoder_2, + prompt=prompt, + device=device, + tokenizer_max_length=self.tokenizer_2_max_length, + ) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_videos_per_prompt, seq_len) + + _, seq_len_2, _ = prompt_embeds_2.shape + prompt_embeds_2 = prompt_embeds_2.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_2 = prompt_embeds_2.view(batch_size * num_videos_per_prompt, seq_len_2, -1) + prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_mask_2 = prompt_embeds_mask_2.view(batch_size * num_videos_per_prompt, seq_len_2) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds_mask = prompt_embeds_mask.to(dtype=dtype, device=device) + prompt_embeds_2 = prompt_embeds_2.to(dtype=dtype, device=device) + prompt_embeds_mask_2 = prompt_embeds_mask_2.to(dtype=dtype, device=device) + + return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + prompt_embeds_2=None, + prompt_embeds_mask_2=None, + negative_prompt_embeds_2=None, + negative_prompt_embeds_mask_2=None, + ): + if height is None and width is not None: + raise ValueError("If `width` is provided, `height` also have to be provided.") + elif width is None and height is not None: + raise ValueError("If `height` is provided, `width` also have to be provided.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + + if prompt_embeds_2 is not None and prompt_embeds_mask_2 is None: + raise ValueError( + "If `prompt_embeds_2` are provided, `prompt_embeds_mask_2` also have to be passed. Make sure to generate `prompt_embeds_mask_2` from the same text encoder that was used to generate `prompt_embeds_2`." + ) + if negative_prompt_embeds_2 is not None and negative_prompt_embeds_mask_2 is None: + raise ValueError( + "If `negative_prompt_embeds_2` are provided, `negative_prompt_embeds_mask_2` also have to be passed. Make sure to generate `negative_prompt_embeds_mask_2` from the same text encoder that was used to generate `negative_prompt_embeds_2`." + ) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 32, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def prepare_cond_latents_and_mask(self, latents, dtype: torch.dtype | None, device: torch.device | None): + """ + Prepare conditional latents and mask for t2v generation. + + Args: + latents: Main latents tensor (B, C, F, H, W) + + Returns: + tuple: (cond_latents_concat, mask_concat) - both are zero tensors for t2v + """ + batch, channels, frames, height, width = latents.shape + + cond_latents_concat = torch.zeros(batch, channels, frames, height, width, dtype=dtype, device=device) + + mask_concat = torch.zeros(batch, 1, frames, height, width, dtype=dtype, device=device) + + return cond_latents_concat, mask_concat + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_frames: int = 121, + num_inference_steps: int = 50, + sigmas: list[float] = None, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + prompt_embeds_2: torch.Tensor | None = None, + prompt_embeds_mask_2: torch.Tensor | None = None, + negative_prompt_embeds_2: torch.Tensor | None = None, + negative_prompt_embeds_mask_2: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. + height (`int`, *optional*): + The height in pixels of the generated video. + width (`int`, *optional*): + The width in pixels of the generated video. + num_frames (`int`, defaults to `121`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated mask for prompt embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated mask for negative prompt embeddings. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated text embeddings from the second text encoder. Can be used to easily tweak text inputs. + prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated mask for prompt embeddings from the second text encoder. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings from the second text encoder. + negative_prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated mask for negative prompt embeddings from the second text encoder. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. Choose between "np", "pt", or "latent". + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideo15PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + Returns: + [`~HunyuanVideo15PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideo15PipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated videos. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + prompt_embeds_2=prompt_embeds_2, + prompt_embeds_mask_2=prompt_embeds_mask_2, + negative_prompt_embeds_2=negative_prompt_embeds_2, + negative_prompt_embeds_mask_2=negative_prompt_embeds_mask_2, + ) + + if height is None and width is None: + height, width = self.video_processor.calculate_default_height_width( + self.default_aspect_ratio[1], self.default_aspect_ratio[0], self.target_size + ) + + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + batch_size=batch_size, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + prompt_embeds_2=prompt_embeds_2, + prompt_embeds_mask_2=prompt_embeds_mask_2, + ) + + if self.guider._enabled and self.guider.num_conditions > 1: + ( + negative_prompt_embeds, + negative_prompt_embeds_mask, + negative_prompt_embeds_2, + negative_prompt_embeds_mask_2, + ) = self.encode_prompt( + prompt=negative_prompt, + device=device, + dtype=self.transformer.dtype, + batch_size=batch_size, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + prompt_embeds_2=negative_prompt_embeds_2, + prompt_embeds_mask_2=negative_prompt_embeds_mask_2, + ) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + + # 5. Prepare latent variables + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + self.num_channels_latents, + height, + width, + num_frames, + self.transformer.dtype, + device, + generator, + latents, + ) + cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask(latents, self.transformer.dtype, device) + image_embeds = torch.zeros( + batch_size, + self.vision_num_semantic_tokens, + self.vision_states_dim, + dtype=self.transformer.dtype, + device=device, + ) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents, cond_latents_concat, mask_concat], dim=1) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) + + # Step 1: Collect model inputs needed for the guidance method + # conditional inputs should always be first element in the tuple + guider_inputs = { + "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds), + "encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask), + "encoder_hidden_states_2": (prompt_embeds_2, negative_prompt_embeds_2), + "encoder_attention_mask_2": (prompt_embeds_mask_2, negative_prompt_embeds_mask_2), + } + + # Step 2: Update guider's internal state for this denoising step + self.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t) + + # Step 3: Prepare batched model inputs based on the guidance method + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). + guider_state = self.guider.prepare_inputs(guider_inputs) + # Step 4: Run the denoiser for each batch + # Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.). + # We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred. + for guider_state_batch in guider_state: + self.guider.prepare_models(self.transformer) + + # Extract conditioning kwargs for this batch (e.g., encoder_hidden_states) + cond_kwargs = { + input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys() + } + + # e.g. "pred_cond"/"pred_uncond" + context_name = getattr(guider_state_batch, self.guider._identifier_key) + with self.transformer.cache_context(context_name): + # Run denoiser and store noise prediction in this batch + guider_state_batch.noise_pred = self.transformer( + hidden_states=latent_model_input, + image_embeds=image_embeds, + timestep=timestep, + attention_kwargs=self.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + + # Cleanup model (e.g., remove hooks) + self.guider.cleanup_models(self.transformer) + + # Step 5: Combine predictions using the guidance method + # The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm. + # Continuing the CFG example, the guider receives: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0 + # {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1 + # ] + # And extracts predictions using the __guidance_identifier__: + # pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond + # pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond + # Then applies CFG formula: + # noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + # Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond) + noise_pred = self.guider(guider_state)[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + # 8. decode the latents to video and postprocess + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HunyuanVideo15PipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py new file mode 100644 index 0000000000000000000000000000000000000000..791dec0735248a12a4c68743084bded7663a3c25 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py @@ -0,0 +1,960 @@ +# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import re +from typing import Any + +import numpy as np +import PIL +import torch +from transformers import ( + ByT5Tokenizer, + Qwen2_5_VLTextModel, + Qwen2Tokenizer, + SiglipImageProcessor, + SiglipVisionModel, + T5EncoderModel, +) + +from ...guiders import ClassifierFreeGuidance +from ...models import AutoencoderKLHunyuanVideo15, HunyuanVideo15Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .image_processor import HunyuanVideo15ImageProcessor +from .pipeline_output import HunyuanVideo15PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import HunyuanVideo15ImageToVideoPipeline + >>> from diffusers.utils import export_to_video + + >>> model_id = "hunyuanvideo-community/HunyuanVideo-1.5-480p_i2v" + >>> pipe = HunyuanVideo15ImageToVideoPipeline.from_pretrained(model_id, torch_dtype=torch.float16) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/wan_i2v_input.JPG") + + >>> output = pipe( + ... prompt="Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.", + ... image=image, + ... num_inference_steps=50, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=24) + ``` +""" + + +# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.format_text_input +def format_text_input(prompt: list[str], system_message: str) -> list[dict[str, Any]]: + """ + Apply text to template. + + Args: + prompt (list[str]): Input text. + system_message (str): System message. + + Returns: + list[dict[str, Any]]: List of chat conversation. + """ + + template = [ + [{"role": "system", "content": system_message}, {"role": "user", "content": p if p else " "}] for p in prompt + ] + + return template + + +# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.extract_glyph_texts +def extract_glyph_texts(prompt: str) -> list[str]: + """ + Extract glyph texts from prompt using regex pattern. + + Args: + prompt: Input prompt string + + Returns: + List of extracted glyph texts + """ + pattern = r"\"(.*?)\"|“(.*?)”" + matches = re.findall(pattern, prompt) + result = [match[0] or match[1] for match in matches] + result = list(dict.fromkeys(result)) if len(result) > 1 else result + + if result: + formatted_result = ". ".join([f'Text "{text}"' for text in result]) + ". " + else: + formatted_result = None + + return formatted_result + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class HunyuanVideo15ImageToVideoPipeline(DiffusionPipeline): + r""" + Pipeline for image-to-video generation using HunyuanVideo1.5. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + transformer ([`HunyuanVideo15Transformer3DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + vae ([`AutoencoderKLHunyuanVideo15`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer]. + text_encoder_2 ([`T5EncoderModel`]): + [T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel) + variant. + tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer] + guider ([`ClassifierFreeGuidance`]): + [ClassifierFreeGuidance]for classifier free guidance. + image_encoder ([`SiglipVisionModel`]): + [SiglipVisionModel](https://huggingface.co/docs/transformers/en/model_doc/siglip#transformers.SiglipVisionModel) + variant. + feature_extractor ([`SiglipImageProcessor`]): + [SiglipImageProcessor](https://huggingface.co/docs/transformers/en/model_doc/siglip#transformers.SiglipImageProcessor) + variant. + """ + + model_cpu_offload_seq = "image_encoder->text_encoder->transformer->vae" + + def __init__( + self, + text_encoder: Qwen2_5_VLTextModel, + tokenizer: Qwen2Tokenizer, + transformer: HunyuanVideo15Transformer3DModel, + vae: AutoencoderKLHunyuanVideo15, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: T5EncoderModel, + tokenizer_2: ByT5Tokenizer, + guider: ClassifierFreeGuidance, + image_encoder: SiglipVisionModel, + feature_extractor: SiglipImageProcessor, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + guider=guider, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + + self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 16 + self.video_processor = HunyuanVideo15ImageProcessor( + vae_scale_factor=self.vae_scale_factor_spatial, do_resize=False, do_convert_rgb=True + ) + self.target_size = self.transformer.config.target_size if getattr(self, "transformer", None) else 640 + self.vision_states_dim = ( + self.transformer.config.image_embed_dim if getattr(self, "transformer", None) else 1152 + ) + self.num_channels_latents = self.vae.config.latent_channels if hasattr(self, "vae") else 32 + # fmt: off + self.system_message = "You are a helpful assistant. Describe the video by detailing the following aspects: \ + 1. The main content and theme of the video. \ + 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ + 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ + 4. background environment, light, style and atmosphere. \ + 5. camera angles, movements, and transitions used in the video." + # fmt: on + self.prompt_template_encode_start_idx = 108 + self.tokenizer_max_length = 1000 + self.tokenizer_2_max_length = 256 + self.vision_num_semantic_tokens = 729 + + @staticmethod + # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline._get_mllm_prompt_embeds + def _get_mllm_prompt_embeds( + text_encoder: Qwen2_5_VLTextModel, + tokenizer: Qwen2Tokenizer, + prompt: str | list[str], + device: torch.device, + tokenizer_max_length: int = 1000, + num_hidden_layers_to_skip: int = 2, + # fmt: off + system_message: str = "You are a helpful assistant. Describe the video by detailing the following aspects: \ + 1. The main content and theme of the video. \ + 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ + 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ + 4. background environment, light, style and atmosphere. \ + 5. camera angles, movements, and transitions used in the video.", + # fmt: on + crop_start: int = 108, + ) -> tuple[torch.Tensor, torch.Tensor]: + prompt = [prompt] if isinstance(prompt, str) else prompt + + prompt = format_text_input(prompt, system_message) + + text_inputs = tokenizer.apply_chat_template( + prompt, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + padding="max_length", + max_length=tokenizer_max_length + crop_start, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + return prompt_embeds, prompt_attention_mask + + @staticmethod + # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline._get_byt5_prompt_embeds + def _get_byt5_prompt_embeds( + tokenizer: ByT5Tokenizer, + text_encoder: T5EncoderModel, + prompt: str | list[str], + device: torch.device, + tokenizer_max_length: int = 256, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + + glyph_texts = [extract_glyph_texts(p) for p in prompt] + + prompt_embeds_list = [] + prompt_embeds_mask_list = [] + + for glyph_text in glyph_texts: + if glyph_text is None: + glyph_text_embeds = torch.zeros( + (1, tokenizer_max_length, text_encoder.config.d_model), device=device, dtype=text_encoder.dtype + ) + glyph_text_embeds_mask = torch.zeros((1, tokenizer_max_length), device=device, dtype=torch.int64) + else: + txt_tokens = tokenizer( + glyph_text, + padding="max_length", + max_length=tokenizer_max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ).to(device) + + glyph_text_embeds = text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask.float(), + )[0] + glyph_text_embeds = glyph_text_embeds.to(device=device) + glyph_text_embeds_mask = txt_tokens.attention_mask.to(device=device) + + prompt_embeds_list.append(glyph_text_embeds) + prompt_embeds_mask_list.append(glyph_text_embeds_mask) + + prompt_embeds = torch.cat(prompt_embeds_list, dim=0) + prompt_embeds_mask = torch.cat(prompt_embeds_mask_list, dim=0) + + return prompt_embeds, prompt_embeds_mask + + @staticmethod + def _get_image_latents( + vae: AutoencoderKLHunyuanVideo15, + image_processor: HunyuanVideo15ImageProcessor, + image: PIL.Image.Image, + height: int, + width: int, + device: torch.device, + ) -> torch.Tensor: + vae_dtype = vae.dtype + image_tensor = image_processor.preprocess(image, height=height, width=width).to(device, dtype=vae_dtype) + image_tensor = image_tensor.unsqueeze(2) + image_latents = retrieve_latents(vae.encode(image_tensor), sample_mode="argmax") + image_latents = image_latents * vae.config.scaling_factor + return image_latents + + @staticmethod + def _get_image_embeds( + image_encoder: SiglipVisionModel, + feature_extractor: SiglipImageProcessor, + image: PIL.Image.Image, + device: torch.device, + ) -> torch.Tensor: + image_encoder_dtype = next(image_encoder.parameters()).dtype + image = feature_extractor.preprocess(images=image, do_resize=True, return_tensors="pt", do_convert_rgb=True) + image = image.to(device=device, dtype=image_encoder_dtype) + image_enc_hidden_states = image_encoder(**image).last_hidden_state + + return image_enc_hidden_states + + def encode_image( + self, + image: PIL.Image.Image, + batch_size: int, + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + image_embeds = self._get_image_embeds( + image_encoder=self.image_encoder, + feature_extractor=self.feature_extractor, + image=image, + device=device, + ) + image_embeds = image_embeds.repeat(batch_size, 1, 1) + image_embeds = image_embeds.to(device=device, dtype=dtype) + return image_embeds + + # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + dtype: torch.dtype | None = None, + batch_size: int = 1, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + prompt_embeds_2: torch.Tensor | None = None, + prompt_embeds_mask_2: torch.Tensor | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + batch_size (`int`): + batch size of prompts, defaults to 1 + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input + argument. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input + argument using self.tokenizer_2 and self.text_encoder_2. + prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input + argument using self.tokenizer_2 and self.text_encoder_2. + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + if prompt is None: + prompt = [""] * batch_size + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_mllm_prompt_embeds( + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + prompt=prompt, + device=device, + tokenizer_max_length=self.tokenizer_max_length, + system_message=self.system_message, + crop_start=self.prompt_template_encode_start_idx, + ) + + if prompt_embeds_2 is None: + prompt_embeds_2, prompt_embeds_mask_2 = self._get_byt5_prompt_embeds( + tokenizer=self.tokenizer_2, + text_encoder=self.text_encoder_2, + prompt=prompt, + device=device, + tokenizer_max_length=self.tokenizer_2_max_length, + ) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_videos_per_prompt, seq_len) + + _, seq_len_2, _ = prompt_embeds_2.shape + prompt_embeds_2 = prompt_embeds_2.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_2 = prompt_embeds_2.view(batch_size * num_videos_per_prompt, seq_len_2, -1) + prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_mask_2 = prompt_embeds_mask_2.view(batch_size * num_videos_per_prompt, seq_len_2) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds_mask = prompt_embeds_mask.to(dtype=dtype, device=device) + prompt_embeds_2 = prompt_embeds_2.to(dtype=dtype, device=device) + prompt_embeds_mask_2 = prompt_embeds_mask_2.to(dtype=dtype, device=device) + + return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 + + def check_inputs( + self, + prompt, + image: PIL.Image.Image, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + prompt_embeds_2=None, + prompt_embeds_mask_2=None, + negative_prompt_embeds_2=None, + negative_prompt_embeds_mask_2=None, + ): + if not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `PIL.Image.Image` but is {type(image)}") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + + if prompt_embeds_2 is not None and prompt_embeds_mask_2 is None: + raise ValueError( + "If `prompt_embeds_2` are provided, `prompt_embeds_mask_2` also have to be passed. Make sure to generate `prompt_embeds_mask_2` from the same text encoder that was used to generate `prompt_embeds_2`." + ) + if negative_prompt_embeds_2 is not None and negative_prompt_embeds_mask_2 is None: + raise ValueError( + "If `negative_prompt_embeds_2` are provided, `negative_prompt_embeds_mask_2` also have to be passed. Make sure to generate `negative_prompt_embeds_mask_2` from the same text encoder that was used to generate `negative_prompt_embeds_2`." + ) + + # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline.prepare_latents + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 32, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def prepare_cond_latents_and_mask( + self, + latents: torch.Tensor, + image: PIL.Image.Image, + batch_size: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + ): + """ + Prepare conditional latents and mask for t2v generation. + + Args: + latents: Main latents tensor (B, C, F, H, W) + + Returns: + tuple: (cond_latents_concat, mask_concat) - both are zero tensors for t2v + """ + + batch, channels, frames, height, width = latents.shape + + image_latents = self._get_image_latents( + vae=self.vae, + image_processor=self.video_processor, + image=image, + height=height, + width=width, + device=device, + ) + + latent_condition = image_latents.repeat(batch_size, 1, frames, 1, 1) + latent_condition[:, :, 1:, :, :] = 0 + latent_condition = latent_condition.to(device=device, dtype=dtype) + + latent_mask = torch.zeros(batch, 1, frames, height, width, dtype=dtype, device=device) + latent_mask[:, :, 0, :, :] = 1.0 + + return latent_condition, latent_mask + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PIL.Image.Image, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + num_frames: int = 121, + num_inference_steps: int = 50, + sigmas: list[float] = None, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + prompt_embeds_2: torch.Tensor | None = None, + prompt_embeds_mask_2: torch.Tensor | None = None, + negative_prompt_embeds_2: torch.Tensor | None = None, + negative_prompt_embeds_mask_2: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image`): + The input image to condition video generation on. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds` + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. + num_frames (`int`, defaults to `121`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated mask for prompt embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated mask for negative prompt embeddings. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated text embeddings from the second text encoder. Can be used to easily tweak text inputs. + prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated mask for prompt embeddings from the second text encoder. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings from the second text encoder. + negative_prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated mask for negative prompt embeddings from the second text encoder. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. Choose between "np", "pt", or "latent". + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideo15PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + Returns: + [`~HunyuanVideo15PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideo15PipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated videos. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + image=image, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + prompt_embeds_2=prompt_embeds_2, + prompt_embeds_mask_2=prompt_embeds_mask_2, + negative_prompt_embeds_2=negative_prompt_embeds_2, + negative_prompt_embeds_mask_2=negative_prompt_embeds_mask_2, + ) + + height, width = self.video_processor.calculate_default_height_width( + height=image.size[1], width=image.size[0], target_size=self.target_size + ) + image = self.video_processor.resize(image, height=height, width=width, resize_mode="crop") + + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode image + image_embeds = self.encode_image( + image=image, + batch_size=batch_size * num_videos_per_prompt, + device=device, + dtype=self.transformer.dtype, + ) + + # 4. Encode input prompt + prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + batch_size=batch_size, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + prompt_embeds_2=prompt_embeds_2, + prompt_embeds_mask_2=prompt_embeds_mask_2, + ) + + if self.guider._enabled and self.guider.num_conditions > 1: + ( + negative_prompt_embeds, + negative_prompt_embeds_mask, + negative_prompt_embeds_2, + negative_prompt_embeds_mask_2, + ) = self.encode_prompt( + prompt=negative_prompt, + device=device, + dtype=self.transformer.dtype, + batch_size=batch_size, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + prompt_embeds_2=negative_prompt_embeds_2, + prompt_embeds_mask_2=negative_prompt_embeds_mask_2, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + + # 6. Prepare latent variables + latents = self.prepare_latents( + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=self.num_channels_latents, + height=height, + width=width, + num_frames=num_frames, + dtype=self.transformer.dtype, + device=device, + generator=generator, + latents=latents, + ) + + cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask( + latents=latents, + image=image, + batch_size=batch_size * num_videos_per_prompt, + height=height, + width=width, + dtype=self.transformer.dtype, + device=device, + ) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents, cond_latents_concat, mask_concat], dim=1) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) + + if self.transformer.config.use_meanflow: + if i == len(timesteps) - 1: + timestep_r = torch.tensor([0.0], device=device) + else: + timestep_r = timesteps[i + 1] + timestep_r = timestep_r.expand(latents.shape[0]).to(latents.dtype) + else: + timestep_r = None + + # Step 1: Collect model inputs needed for the guidance method + # conditional inputs should always be first element in the tuple + guider_inputs = { + "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds), + "encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask), + "encoder_hidden_states_2": (prompt_embeds_2, negative_prompt_embeds_2), + "encoder_attention_mask_2": (prompt_embeds_mask_2, negative_prompt_embeds_mask_2), + } + + # Step 2: Update guider's internal state for this denoising step + self.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t) + + # Step 3: Prepare batched model inputs based on the guidance method + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). + guider_state = self.guider.prepare_inputs(guider_inputs) + # Step 4: Run the denoiser for each batch + # Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.). + # We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred. + for guider_state_batch in guider_state: + self.guider.prepare_models(self.transformer) + + # Extract conditioning kwargs for this batch (e.g., encoder_hidden_states) + cond_kwargs = { + input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys() + } + + # e.g. "pred_cond"/"pred_uncond" + context_name = getattr(guider_state_batch, self.guider._identifier_key) + with self.transformer.cache_context(context_name): + # Run denoiser and store noise prediction in this batch + guider_state_batch.noise_pred = self.transformer( + hidden_states=latent_model_input, + image_embeds=image_embeds, + timestep=timestep, + timestep_r=timestep_r, + attention_kwargs=self.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + + # Cleanup model (e.g., remove hooks) + self.guider.cleanup_models(self.transformer) + + # Step 5: Combine predictions using the guidance method + # The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm. + # Continuing the CFG example, the guider receives: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0 + # {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1 + # ] + # And extracts predictions using the __guidance_identifier__: + # pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond + # pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond + # Then applies CFG formula: + # noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + # Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond) + noise_pred = self.guider(guider_state)[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HunyuanVideo15PipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video1_5/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video1_5/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..ab8767299e55c961bed4d8394e01154029c00544 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuan_video1_5/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class HunyuanVideo15PipelineOutput(BaseOutput): + r""" + Output class for HunyuanVideo1.5 pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuandit/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuandit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8337399106f0585d15fa0a35f607baa2c04b203b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuandit/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_hunyuandit"] = ["HunyuanDiTPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_hunyuandit import HunyuanDiTPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py new file mode 100644 index 0000000000000000000000000000000000000000..b908dd5dfe8307fba0972ac80ca19530ddbae7c6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py @@ -0,0 +1,908 @@ +# Copyright 2025 HunyuanDiT Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import torch +from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, HunyuanDiT2DModel +from ...models.embeddings import get_2d_rotary_pos_embed +from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from ...schedulers import DDPMScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import HunyuanDiTPipeline + + >>> pipe = HunyuanDiTPipeline.from_pretrained( + ... "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> # You may also use English prompt as HunyuanDiT supports both English and Chinese + >>> # prompt = "An astronaut riding a horse" + >>> prompt = "一个宇航员在骑马" + >>> image = pipe(prompt).images[0] + ``` +""" + +STANDARD_RATIO = np.array( + [ + 1.0, # 1:1 + 4.0 / 3.0, # 4:3 + 3.0 / 4.0, # 3:4 + 16.0 / 9.0, # 16:9 + 9.0 / 16.0, # 9:16 + ] +) +STANDARD_SHAPE = [ + [(1024, 1024), (1280, 1280)], # 1:1 + [(1024, 768), (1152, 864), (1280, 960)], # 4:3 + [(768, 1024), (864, 1152), (960, 1280)], # 3:4 + [(1280, 768)], # 16:9 + [(768, 1280)], # 9:16 +] +STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE] +SUPPORTED_SHAPE = [ + (1024, 1024), + (1280, 1280), # 1:1 + (1024, 768), + (1152, 864), + (1280, 960), # 4:3 + (768, 1024), + (864, 1152), + (960, 1280), # 3:4 + (1280, 768), # 16:9 + (768, 1280), # 9:16 +] + + +def map_to_standard_shapes(target_width, target_height): + target_ratio = target_width / target_height + closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio)) + closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height)) + width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx] + return width, height + + +def get_resize_crop_region_for_grid(src, tgt_size): + th = tw = tgt_size + h, w = src + + r = h / w + + # resize + if r > 1: + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class HunyuanDiTPipeline(DiffusionPipeline): + r""" + Pipeline for English/Chinese-to-image generation using HunyuanDiT. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + HunyuanDiT uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by + ourselves) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. We use + `sdxl-vae-fp16-fix`. + text_encoder (`~transformers.BertModel`, `~transformers.CLIPTextModel` | None): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + HunyuanDiT uses a fine-tuned [bilingual CLIP]. + tokenizer (`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer` | None): + A `BertTokenizer` or `CLIPTokenizer` to tokenize text. + transformer ([`HunyuanDiT2DModel`]): + The HunyuanDiT model designed by Tencent Hunyuan. + text_encoder_2 (`T5EncoderModel`): + The mT5 embedder. Specifically, it is 't5-v1_1-xxl'. + tokenizer_2 (`T5Tokenizer`): + The tokenizer for the mT5 embedder. + scheduler ([`DDPMScheduler`]): + A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [ + "safety_checker", + "feature_extractor", + "text_encoder_2", + "tokenizer_2", + "text_encoder", + "tokenizer", + ] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "prompt_embeds_2", + "negative_prompt_embeds_2", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: BertModel, + tokenizer: BertTokenizer, + transformer: HunyuanDiT2DModel, + scheduler: DDPMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + text_encoder_2: T5EncoderModel | None = None, + tokenizer_2: T5Tokenizer | None = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + text_encoder_2=text_encoder_2, + ) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + + def encode_prompt( + self, + prompt: str, + device: torch.device = None, + dtype: torch.dtype = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + max_sequence_length: int | None = None, + text_encoder_index: int = 0, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + dtype (`torch.dtype`): + torch dtype + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + text_encoder_index (`int`, *optional*): + Index of the text encoder to use. `0` for clip and `1` for T5. + """ + if dtype is None: + if self.text_encoder_2 is not None: + dtype = self.text_encoder_2.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + if device is None: + device = self._execution_device + + tokenizers = [self.tokenizer, self.tokenizer_2] + text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = tokenizers[text_encoder_index] + text_encoder = text_encoders[text_encoder_index] + + if max_sequence_length is None: + if text_encoder_index == 0: + max_length = 77 + if text_encoder_index == 1: + max_length = 256 + else: + max_length = max_sequence_length + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = text_encoder( + text_input_ids.to(device), + attention_mask=prompt_attention_mask, + ) + prompt_embeds = prompt_embeds[0] + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_attention_mask = uncond_input.attention_mask.to(device) + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + prompt_embeds_2=None, + negative_prompt_embeds_2=None, + prompt_attention_mask_2=None, + negative_prompt_attention_mask_2=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if prompt_embeds_2 is not None and prompt_attention_mask_2 is None: + raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None: + raise ValueError( + "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`." + ) + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None: + if prompt_embeds_2.shape != negative_prompt_embeds_2.shape: + raise ValueError( + "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but" + f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`" + f" {negative_prompt_embeds_2.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int | None = 50, + guidance_scale: float | None = 5.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float | None = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_2: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_2: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + prompt_attention_mask_2: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask_2: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + guidance_rescale: float = 0.0, + original_size: tuple[int, int] | None = (1024, 1024), + target_size: tuple[int, int] | None = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + use_resolution_binning: bool = True, + ): + r""" + The call function to the pipeline for generation with HunyuanDiT. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback_on_step_end (`Callable[[int, int], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A callback function or a list of callback functions to be called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`list[str]`, *optional*): + A list of tensor inputs that should be passed to the callback function. If not defined, all tensor + inputs will be passed. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise + Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). See Section 3.4 + original_size (`tuple[int, int]`, *optional*, defaults to `(1024, 1024)`): + The original size of the image. Used to calculate the time ids. + target_size (`tuple[int, int]`, *optional*): + The target size of the image. Used to calculate the time ids. + crops_coords_top_left (`tuple[int, int]`, *optional*, defaults to `(0, 0)`): + The top left coordinates of the crop. Used to calculate the time ids. + use_resolution_binning (`bool`, *optional*, defaults to `True`): + Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest + standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960, + 768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. default height and width + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + height = int((height // 16) * 16) + width = int((width // 16) * 16) + + if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE: + width, height = map_to_standard_shapes(width, height) + height = int(height) + width = int(width) + logger.warning(f"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}") + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=77, + text_encoder_index=0, + ) + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds_2, + negative_prompt_embeds=negative_prompt_embeds_2, + prompt_attention_mask=prompt_attention_mask_2, + negative_prompt_attention_mask=negative_prompt_attention_mask_2, + max_sequence_length=256, + text_encoder_index=1, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7 create image_rotary_emb, style embedding & time ids + grid_height = height // 8 // self.transformer.config.patch_size + grid_width = width // 8 // self.transformer.config.patch_size + base_size = 512 // 8 // self.transformer.config.patch_size + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) + image_rotary_emb = get_2d_rotary_pos_embed( + self.transformer.inner_dim // self.transformer.num_heads, + grid_crops_coords, + (grid_height, grid_width), + device=device, + output_type="pt", + ) + + style = torch.tensor([0], device=device) + + target_size = target_size or (height, width) + add_time_ids = list(original_size + target_size + crops_coords_top_left) + add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) + style = torch.cat([style] * 2, dim=0) + + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + prompt_embeds_2 = prompt_embeds_2.to(device=device) + prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device) + add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat( + batch_size * num_images_per_prompt, 1 + ) + style = style.to(device=device).repeat(batch_size * num_images_per_prompt) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( + dtype=latent_model_input.dtype + ) + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + text_embedding_mask=prompt_attention_mask, + encoder_hidden_states_t5=prompt_embeds_2, + text_embedding_mask_t5=prompt_attention_mask_2, + image_meta_size=add_time_ids, + style=style, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2) + negative_prompt_embeds_2 = callback_outputs.pop( + "negative_prompt_embeds_2", negative_prompt_embeds_2 + ) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/i2vgen_xl/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/i2vgen_xl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b24a7e4cee7fb843b9424469a05f511adfa758de --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/i2vgen_xl/__init__.py @@ -0,0 +1,46 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_i2vgen_xl"] = ["I2VGenXLPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_i2vgen_xl import I2VGenXLPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..731ac27a0ff56b048ea53d9ec334377d7e6cee20 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py @@ -0,0 +1,797 @@ +# Copyright 2025 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any + +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import AutoencoderKL +from ...models.unets.unet_i2vgen_xl import I2VGenXLUNet +from ...schedulers import DDIMScheduler +from ...utils import ( + BaseOutput, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import I2VGenXLPipeline + >>> from diffusers.utils import export_to_gif, load_image + + >>> pipeline = I2VGenXLPipeline.from_pretrained( + ... "ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16" + ... ) + >>> pipeline.enable_model_cpu_offload() + + >>> image_url = ( + ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/i2vgen_xl_images/img_0009.png" + ... ) + >>> image = load_image(image_url).convert("RGB") + + >>> prompt = "Papers were floating in the air on a table in the library" + >>> negative_prompt = "Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms" + >>> generator = torch.manual_seed(8888) + + >>> frames = pipeline( + ... prompt=prompt, + ... image=image, + ... num_inference_steps=50, + ... negative_prompt=negative_prompt, + ... guidance_scale=9.0, + ... generator=generator, + ... ).frames[0] + >>> video_path = export_to_gif(frames, "i2v.gif") + ``` +""" + + +@dataclass +class I2VGenXLPipelineOutput(BaseOutput): + r""" + Output class for image-to-video pipeline. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised + PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)` + """ + + frames: torch.Tensor | np.ndarray | list[list[PIL.Image.Image]] + + +class I2VGenXLPipeline( + DeprecatedPipelineMixin, + DiffusionPipeline, + StableDiffusionMixin, +): + _last_supported_version = "0.33.1" + r""" + Pipeline for image-to-video generation as proposed in [I2VGenXL](https://i2vgen-xl.github.io/). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`I2VGenXLUNet`]): + A [`I2VGenXLUNet`] to denoise the encoded video latents. + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + image_encoder: CLIPVisionModelWithProjection, + feature_extractor: CLIPImageProcessor, + unet: I2VGenXLUNet, + scheduler: DDIMScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # `do_resize=False` as we do custom resizing. + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=False) + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + def encode_prompt( + self, + prompt, + device, + num_videos_per_prompt, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_videos_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + # Apply clip_skip to negative prompt embeds + if clip_skip is None: + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + else: + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + negative_prompt_embeds = negative_prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + negative_prompt_embeds = self.text_encoder.text_model.final_layer_norm(negative_prompt_embeds) + + if self.do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + def _encode_image(self, image, device, num_videos_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.video_processor.pil_to_numpy(image) + image = self.video_processor.numpy_to_pt(image) + + # Normalize the image with CLIP training stats. + image = self.feature_extractor( + images=image, + do_normalize=True, + do_center_crop=False, + do_resize=False, + do_rescale=False, + return_tensors="pt", + ).pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + if self.do_classifier_free_guidance: + negative_image_embeddings = torch.zeros_like(image_embeddings) + image_embeddings = torch.cat([negative_image_embeddings, image_embeddings]) + + return image_embeddings + + def decode_latents(self, latents, decode_chunk_size=None): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + if decode_chunk_size is not None: + frames = [] + for i in range(0, latents.shape[0], decode_chunk_size): + frame = self.vae.decode(latents[i : i + decode_chunk_size]).sample + frames.append(frame) + image = torch.cat(frames, dim=0) + else: + image = self.vae.decode(latents).sample + + decode_shape = (batch_size, num_frames, -1) + image.shape[2:] + video = image[None, :].reshape(decode_shape).permute(0, 2, 1, 3, 4) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `list[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + def prepare_image_latents( + self, + image, + device, + num_frames, + num_videos_per_prompt, + ): + image = image.to(device=device) + image_latents = self.vae.encode(image).latent_dist.sample() + image_latents = image_latents * self.vae.config.scaling_factor + + # Add frames dimension to image latents + image_latents = image_latents.unsqueeze(2) + + # Append a position mask for each subsequent frame + # after the initial image latent frame + frame_position_mask = [] + for frame_idx in range(num_frames - 1): + scale = (frame_idx + 1) / (num_frames - 1) + frame_position_mask.append(torch.ones_like(image_latents[:, :, :1]) * scale) + if frame_position_mask: + frame_position_mask = torch.cat(frame_position_mask, dim=2) + image_latents = torch.cat([image_latents, frame_position_mask], dim=2) + + # duplicate image_latents for each generation per prompt, using mps friendly method + image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1, 1) + + if self.do_classifier_free_guidance: + image_latents = torch.cat([image_latents] * 2) + + return image_latents + + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + height: int | None = 704, + width: int | None = 1280, + target_fps: int | None = 16, + num_frames: int = 16, + num_inference_steps: int = 50, + guidance_scale: float = 9.0, + negative_prompt: str | list[str] | None = None, + eta: float = 0.0, + num_videos_per_prompt: int | None = 1, + decode_chunk_size: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = 1, + ): + r""" + The call function to the pipeline for image-to-video generation with [`I2VGenXLPipeline`]. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`PIL.Image.Image` or `list[PIL.Image.Image]` or `torch.Tensor`): + Image or images to guide image generation. If you provide a tensor, it needs to be compatible with + [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + target_fps (`int`, *optional*): + Frames per second. The rate at which the generated images shall be exported to a video after + generation. This is also used as a "micro-condition" while generation. + num_frames (`int`, *optional*): + The number of video frames to generate. + num_inference_steps (`int`, *optional*): + The number of denoising steps. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + num_videos_per_prompt (`int`, *optional*): + The number of images to generate per prompt. + decode_chunk_size (`int`, *optional*): + The number of frames to decode at a time. The higher the chunk size, the higher the temporal + consistency between frames, but also the higher the memory consumption. By default, the decoder will + decode all frames at once for maximal quality. Reduce `decode_chunk_size` to reduce memory usage. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Examples: + + Returns: + [`pipelines.i2vgen_xl.pipeline_i2vgen_xl.I2VGenXLPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`pipelines.i2vgen_xl.pipeline_i2vgen_xl.I2VGenXLPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, image, height, width, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + self._guidance_scale = guidance_scale + + # 3.1 Encode input text prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 3.2 Encode image prompt + # 3.2.1 Image encodings. + # https://github.com/ali-vilab/i2vgen-xl/blob/2539c9262ff8a2a22fa9daecbfd13f0a2dbc32d0/tools/inferences/inference_i2vgen_entrance.py#L114 + cropped_image = _center_crop_wide(image, (width, width)) + cropped_image = _resize_bilinear( + cropped_image, (self.feature_extractor.crop_size["width"], self.feature_extractor.crop_size["height"]) + ) + image_embeddings = self._encode_image(cropped_image, device, num_videos_per_prompt) + + # 3.2.2 Image latents. + resized_image = _center_crop_wide(image, (width, height)) + image = self.video_processor.preprocess(resized_image).to(device=device, dtype=image_embeddings.dtype) + image_latents = self.prepare_image_latents( + image, + device=device, + num_frames=num_frames, + num_videos_per_prompt=num_videos_per_prompt, + ) + + # 3.3 Prepare additional conditions for the UNet. + if self.do_classifier_free_guidance: + fps_tensor = torch.tensor([target_fps, target_fps]).to(device) + else: + fps_tensor = torch.tensor([target_fps]).to(device) + fps_tensor = fps_tensor.repeat(batch_size * num_videos_per_prompt, 1).ravel() + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + fps=fps_tensor, + image_latents=image_latents, + image_embeddings=image_embeddings, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # reshape latents + batch_size, channel, frames, width, height = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channel, width, height) + noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channel, width, height) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # reshape latents back + latents = latents[None, :].reshape(batch_size, frames, channel, width, height).permute(0, 2, 1, 3, 4) + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # 8. Post processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # 9. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return I2VGenXLPipelineOutput(frames=video) + + +# The following utilities are taken and adapted from +# https://github.com/ali-vilab/i2vgen-xl/blob/main/utils/transforms.py. + + +def _convert_pt_to_pil(image: torch.Tensor | list[torch.Tensor]): + if isinstance(image, list) and isinstance(image[0], torch.Tensor): + image = torch.cat(image, 0) + + if isinstance(image, torch.Tensor): + if image.ndim == 3: + image = image.unsqueeze(0) + + image_numpy = VaeImageProcessor.pt_to_numpy(image) + image_pil = VaeImageProcessor.numpy_to_pil(image_numpy) + image = image_pil + + return image + + +def _resize_bilinear( + image: torch.Tensor | list[torch.Tensor] | PIL.Image.Image | list[PIL.Image.Image], resolution: tuple[int, int] +): + # First convert the images to PIL in case they are float tensors (only relevant for tests now). + image = _convert_pt_to_pil(image) + + if isinstance(image, list): + image = [u.resize(resolution, PIL.Image.BILINEAR) for u in image] + else: + image = image.resize(resolution, PIL.Image.BILINEAR) + return image + + +def _center_crop_wide( + image: torch.Tensor | list[torch.Tensor] | PIL.Image.Image | list[PIL.Image.Image], resolution: tuple[int, int] +): + # First convert the images to PIL in case they are float tensors (only relevant for tests now). + image = _convert_pt_to_pil(image) + + if isinstance(image, list): + scale = min(image[0].size[0] / resolution[0], image[0].size[1] / resolution[1]) + image = [u.resize((round(u.width // scale), round(u.height // scale)), resample=PIL.Image.BOX) for u in image] + + # center crop + x1 = (image[0].width - resolution[0]) // 2 + y1 = (image[0].height - resolution[1]) // 2 + image = [u.crop((x1, y1, x1 + resolution[0], y1 + resolution[1])) for u in image] + return image + else: + scale = min(image.size[0] / resolution[0], image.size[1] / resolution[1]) + image = image.resize((round(image.width // scale), round(image.height // scale)), resample=PIL.Image.BOX) + x1 = (image.width - resolution[0]) // 2 + y1 = (image.height - resolution[1]) // 2 + image = image.crop((x1, y1, x1 + resolution[0], y1 + resolution[1])) + return image diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..606f7b378a79489bbcbaa87db2040bd4196bbd8a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky/__init__.py @@ -0,0 +1,66 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_kandinsky"] = ["KandinskyPipeline"] + _import_structure["pipeline_kandinsky_combined"] = [ + "KandinskyCombinedPipeline", + "KandinskyImg2ImgCombinedPipeline", + "KandinskyInpaintCombinedPipeline", + ] + _import_structure["pipeline_kandinsky_img2img"] = ["KandinskyImg2ImgPipeline"] + _import_structure["pipeline_kandinsky_inpaint"] = ["KandinskyInpaintPipeline"] + _import_structure["pipeline_kandinsky_prior"] = ["KandinskyPriorPipeline", "KandinskyPriorPipelineOutput"] + _import_structure["text_encoder"] = ["MultilingualCLIP"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + + else: + from .pipeline_kandinsky import KandinskyPipeline + from .pipeline_kandinsky_combined import ( + KandinskyCombinedPipeline, + KandinskyImg2ImgCombinedPipeline, + KandinskyInpaintCombinedPipeline, + ) + from .pipeline_kandinsky_img2img import KandinskyImg2ImgPipeline + from .pipeline_kandinsky_inpaint import KandinskyInpaintPipeline + from .pipeline_kandinsky_prior import KandinskyPriorPipeline, KandinskyPriorPipelineOutput + from .text_encoder import MultilingualCLIP + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky.py new file mode 100644 index 0000000000000000000000000000000000000000..d635057f2b05319058529ac815ceb41df6babdcd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky.py @@ -0,0 +1,415 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable + +import torch +from transformers import ( + XLMRobertaTokenizer, +) + +from ...models import UNet2DConditionModel, VQModel +from ...schedulers import DDIMScheduler, DDPMScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .text_encoder import MultilingualCLIP + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyPipeline, KandinskyPriorPipeline + >>> import torch + + >>> pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/Kandinsky-2-1-prior") + >>> pipe_prior.to("cuda") + + >>> prompt = "red cat, 4k photo" + >>> out = pipe_prior(prompt) + >>> image_emb = out.image_embeds + >>> negative_image_emb = out.negative_image_embeds + + >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1") + >>> pipe.to("cuda") + + >>> image = pipe( + ... prompt, + ... image_embeds=image_emb, + ... negative_image_embeds=negative_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=100, + ... ).images + + >>> image[0].save("cat.png") + ``` +""" + + +def get_new_h_w(h, w, scale_factor=8): + new_h = h // scale_factor**2 + if h % scale_factor**2 != 0: + new_h += 1 + new_w = w // scale_factor**2 + if w % scale_factor**2 != 0: + new_w += 1 + return new_h * scale_factor, new_w * scale_factor + + +class KandinskyPipeline(DiffusionPipeline): + """ + Pipeline for text-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`MultilingualCLIP`]): + Frozen text-encoder. + tokenizer ([`XLMRobertaTokenizer`]): + Tokenizer of class + scheduler (`DDIMScheduler` | `DDPMScheduler`): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + """ + + model_cpu_offload_seq = "text_encoder->unet->movq" + + def __init__( + self, + text_encoder: MultilingualCLIP, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler | DDPMScheduler, + movq: VQModel, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + movq=movq, + ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + truncation=True, + max_length=77, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_input_ids = text_input_ids.to(device) + text_mask = text_inputs.attention_mask.to(device) + + prompt_embeds, text_encoder_hidden_states = self.text_encoder( + input_ids=text_input_ids, attention_mask=text_mask + ) + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=77, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + uncond_text_input_ids = uncond_input.input_ids.to(device) + uncond_text_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder( + input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask + ) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str], + image_embeds: torch.Tensor | list[torch.Tensor], + negative_image_embeds: torch.Tensor | list[torch.Tensor], + negative_prompt: str | list[str] | None = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + image_embeds (`torch.Tensor` or `list[torch.Tensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + negative_image_embeds (`torch.Tensor` or `list[torch.Tensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + + if do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( + dtype=prompt_embeds.dtype, device=device + ) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps_tensor = self.scheduler.timesteps + + num_channels_latents = self.unet.config.in_channels + + height, width = get_new_h_w(height, width, self.movq_scale_factor) + + # create initial latent + latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + latents, + self.scheduler, + ) + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + ).prev_sample + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + self.maybe_free_model_hooks() + + if output_type not in ["pt", "np", "pil"]: + raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py new file mode 100644 index 0000000000000000000000000000000000000000..21d5c98efa905139cf5e2c4aaabe8899666bb75c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py @@ -0,0 +1,817 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable + +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, + XLMRobertaTokenizer, +) + +from ...models import PriorTransformer, UNet2DConditionModel, VQModel +from ...schedulers import DDIMScheduler, DDPMScheduler, UnCLIPScheduler +from ...utils import ( + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from .pipeline_kandinsky import KandinskyPipeline +from .pipeline_kandinsky_img2img import KandinskyImg2ImgPipeline +from .pipeline_kandinsky_inpaint import KandinskyInpaintPipeline +from .pipeline_kandinsky_prior import KandinskyPriorPipeline +from .text_encoder import MultilingualCLIP + + +TEXT2IMAGE_EXAMPLE_DOC_STRING = """ + Examples: + ```py + from diffusers import AutoPipelineForText2Image + import torch + + pipe = AutoPipelineForText2Image.from_pretrained( + "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16 + ) + pipe.enable_model_cpu_offload() + + prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k" + + image = pipe(prompt=prompt, num_inference_steps=25).images[0] + ``` +""" + +IMAGE2IMAGE_EXAMPLE_DOC_STRING = """ + Examples: + ```py + from diffusers import AutoPipelineForImage2Image + import torch + import requests + from io import BytesIO + from PIL import Image + import os + + pipe = AutoPipelineForImage2Image.from_pretrained( + "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16 + ) + pipe.enable_model_cpu_offload() + + prompt = "A fantasy landscape, Cinematic lighting" + negative_prompt = "low quality, bad quality" + + url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + + response = requests.get(url) + image = Image.open(BytesIO(response.content)).convert("RGB") + image.thumbnail((768, 768)) + + image = pipe(prompt=prompt, image=original_image, num_inference_steps=25).images[0] + ``` +""" + +INPAINT_EXAMPLE_DOC_STRING = """ + Examples: + ```py + from diffusers import AutoPipelineForInpainting + from diffusers.utils import load_image + import torch + import numpy as np + + pipe = AutoPipelineForInpainting.from_pretrained( + "kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16 + ) + pipe.enable_model_cpu_offload() + + prompt = "A fantasy landscape, Cinematic lighting" + negative_prompt = "low quality, bad quality" + + original_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png" + ) + + mask = np.zeros((768, 768), dtype=np.float32) + # Let's mask out an area above the cat's head + mask[:250, 250:-250] = 1 + + image = pipe(prompt=prompt, image=original_image, mask_image=mask, num_inference_steps=25).images[0] + ``` +""" + + +class KandinskyCombinedPipeline(DiffusionPipeline): + """ + Combined Pipeline for text-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`MultilingualCLIP`]): + Frozen text-encoder. + tokenizer ([`XLMRobertaTokenizer`]): + Tokenizer of class + scheduler (`DDIMScheduler` | `DDPMScheduler`): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + prior_prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + prior_image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen image-encoder. + prior_text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + prior_tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + prior_scheduler ([`UnCLIPScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + """ + + _load_connected_pipes = True + model_cpu_offload_seq = "text_encoder->unet->movq->prior_prior->prior_image_encoder->prior_text_encoder" + _exclude_from_cpu_offload = ["prior_prior"] + + def __init__( + self, + text_encoder: MultilingualCLIP, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler | DDPMScheduler, + movq: VQModel, + prior_prior: PriorTransformer, + prior_image_encoder: CLIPVisionModelWithProjection, + prior_text_encoder: CLIPTextModelWithProjection, + prior_tokenizer: CLIPTokenizer, + prior_scheduler: UnCLIPScheduler, + prior_image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + movq=movq, + prior_prior=prior_prior, + prior_image_encoder=prior_image_encoder, + prior_text_encoder=prior_text_encoder, + prior_tokenizer=prior_tokenizer, + prior_scheduler=prior_scheduler, + prior_image_processor=prior_image_processor, + ) + self.prior_pipe = KandinskyPriorPipeline( + prior=prior_prior, + image_encoder=prior_image_encoder, + text_encoder=prior_text_encoder, + tokenizer=prior_tokenizer, + scheduler=prior_scheduler, + image_processor=prior_image_processor, + ) + self.decoder_pipe = KandinskyPipeline( + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + movq=movq, + ) + + def enable_xformers_memory_efficient_attention(self, attention_op: Callable | None = None): + self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) + + def enable_sequential_cpu_offload(self, gpu_id: int | None = None, device: torch.device | str = None): + r""" + Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗 + Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a + GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis. + Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower. + """ + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + + def progress_bar(self, iterable=None, total=None): + self.prior_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.enable_model_cpu_offload() + + def set_progress_bar_config(self, **kwargs): + self.prior_pipe.set_progress_bar_config(**kwargs) + self.decoder_pipe.set_progress_bar_config(**kwargs) + + @torch.no_grad() + @replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + height: int = 512, + width: int = 512, + prior_guidance_scale: float = 4.0, + prior_num_inference_steps: int = 25, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + prior_num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + prior_outputs = self.prior_pipe( + prompt=prompt, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + num_inference_steps=prior_num_inference_steps, + generator=generator, + latents=latents, + guidance_scale=prior_guidance_scale, + output_type="pt", + return_dict=False, + ) + image_embeds = prior_outputs[0] + negative_image_embeds = prior_outputs[1] + + prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt + + if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0: + prompt = (image_embeds.shape[0] // len(prompt)) * prompt + + outputs = self.decoder_pipe( + prompt=prompt, + image_embeds=image_embeds, + negative_image_embeds=negative_image_embeds, + width=width, + height=height, + num_inference_steps=num_inference_steps, + generator=generator, + guidance_scale=guidance_scale, + output_type=output_type, + callback=callback, + callback_steps=callback_steps, + return_dict=return_dict, + ) + + self.maybe_free_model_hooks() + + return outputs + + +class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline): + """ + Combined Pipeline for image-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`MultilingualCLIP`]): + Frozen text-encoder. + tokenizer ([`XLMRobertaTokenizer`]): + Tokenizer of class + scheduler (`DDIMScheduler` | `DDPMScheduler`): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + prior_prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + prior_image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen image-encoder. + prior_text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + prior_tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + prior_scheduler ([`UnCLIPScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + """ + + _load_connected_pipes = True + model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->text_encoder->unet->movq" + _exclude_from_cpu_offload = ["prior_prior"] + + def __init__( + self, + text_encoder: MultilingualCLIP, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler | DDPMScheduler, + movq: VQModel, + prior_prior: PriorTransformer, + prior_image_encoder: CLIPVisionModelWithProjection, + prior_text_encoder: CLIPTextModelWithProjection, + prior_tokenizer: CLIPTokenizer, + prior_scheduler: UnCLIPScheduler, + prior_image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + movq=movq, + prior_prior=prior_prior, + prior_image_encoder=prior_image_encoder, + prior_text_encoder=prior_text_encoder, + prior_tokenizer=prior_tokenizer, + prior_scheduler=prior_scheduler, + prior_image_processor=prior_image_processor, + ) + self.prior_pipe = KandinskyPriorPipeline( + prior=prior_prior, + image_encoder=prior_image_encoder, + text_encoder=prior_text_encoder, + tokenizer=prior_tokenizer, + scheduler=prior_scheduler, + image_processor=prior_image_processor, + ) + self.decoder_pipe = KandinskyImg2ImgPipeline( + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + movq=movq, + ) + + def enable_xformers_memory_efficient_attention(self, attention_op: Callable | None = None): + self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) + + def enable_sequential_cpu_offload(self, gpu_id: int | None = None, device: torch.device | str = None): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + + def progress_bar(self, iterable=None, total=None): + self.prior_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.enable_model_cpu_offload() + + def set_progress_bar_config(self, **kwargs): + self.prior_pipe.set_progress_bar_config(**kwargs) + self.decoder_pipe.set_progress_bar_config(**kwargs) + + @torch.no_grad() + @replace_example_docstring(IMAGE2IMAGE_EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str], + image: torch.Tensor | PIL.Image.Image | list[torch.Tensor] | list[PIL.Image.Image], + negative_prompt: str | list[str] | None = None, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + strength: float = 0.3, + height: int = 512, + width: int = 512, + prior_guidance_scale: float = 4.0, + prior_num_inference_steps: int = 25, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded + again. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 0.3): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + prior_num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + prior_outputs = self.prior_pipe( + prompt=prompt, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + num_inference_steps=prior_num_inference_steps, + generator=generator, + latents=latents, + guidance_scale=prior_guidance_scale, + output_type="pt", + return_dict=False, + ) + image_embeds = prior_outputs[0] + negative_image_embeds = prior_outputs[1] + + prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt + image = [image] if isinstance(prompt, PIL.Image.Image) else image + + if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0: + prompt = (image_embeds.shape[0] // len(prompt)) * prompt + + if ( + isinstance(image, (list, tuple)) + and len(image) < image_embeds.shape[0] + and image_embeds.shape[0] % len(image) == 0 + ): + image = (image_embeds.shape[0] // len(image)) * image + + outputs = self.decoder_pipe( + prompt=prompt, + image=image, + image_embeds=image_embeds, + negative_image_embeds=negative_image_embeds, + strength=strength, + width=width, + height=height, + num_inference_steps=num_inference_steps, + generator=generator, + guidance_scale=guidance_scale, + output_type=output_type, + callback=callback, + callback_steps=callback_steps, + return_dict=return_dict, + ) + + self.maybe_free_model_hooks() + + return outputs + + +class KandinskyInpaintCombinedPipeline(DiffusionPipeline): + """ + Combined Pipeline for generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`MultilingualCLIP`]): + Frozen text-encoder. + tokenizer ([`XLMRobertaTokenizer`]): + Tokenizer of class + scheduler (`DDIMScheduler` | `DDPMScheduler`): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + prior_prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + prior_image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen image-encoder. + prior_text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + prior_tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + prior_scheduler ([`UnCLIPScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + """ + + _load_connected_pipes = True + model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->text_encoder->unet->movq" + _exclude_from_cpu_offload = ["prior_prior"] + + def __init__( + self, + text_encoder: MultilingualCLIP, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler | DDPMScheduler, + movq: VQModel, + prior_prior: PriorTransformer, + prior_image_encoder: CLIPVisionModelWithProjection, + prior_text_encoder: CLIPTextModelWithProjection, + prior_tokenizer: CLIPTokenizer, + prior_scheduler: UnCLIPScheduler, + prior_image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + movq=movq, + prior_prior=prior_prior, + prior_image_encoder=prior_image_encoder, + prior_text_encoder=prior_text_encoder, + prior_tokenizer=prior_tokenizer, + prior_scheduler=prior_scheduler, + prior_image_processor=prior_image_processor, + ) + self.prior_pipe = KandinskyPriorPipeline( + prior=prior_prior, + image_encoder=prior_image_encoder, + text_encoder=prior_text_encoder, + tokenizer=prior_tokenizer, + scheduler=prior_scheduler, + image_processor=prior_image_processor, + ) + self.decoder_pipe = KandinskyInpaintPipeline( + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + movq=movq, + ) + + def enable_xformers_memory_efficient_attention(self, attention_op: Callable | None = None): + self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) + + def enable_sequential_cpu_offload(self, gpu_id: int | None = None, device: torch.device | str = None): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + + def progress_bar(self, iterable=None, total=None): + self.prior_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.enable_model_cpu_offload() + + def set_progress_bar_config(self, **kwargs): + self.prior_pipe.set_progress_bar_config(**kwargs) + self.decoder_pipe.set_progress_bar_config(**kwargs) + + @torch.no_grad() + @replace_example_docstring(INPAINT_EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str], + image: torch.Tensor | PIL.Image.Image | list[torch.Tensor] | list[PIL.Image.Image], + mask_image: torch.Tensor | PIL.Image.Image | list[torch.Tensor] | list[PIL.Image.Image], + negative_prompt: str | list[str] | None = None, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + height: int = 512, + width: int = 512, + prior_guidance_scale: float = 4.0, + prior_num_inference_steps: int = 25, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded + again. + mask_image (`np.array`): + Tensor representing an image batch, to mask `image`. White pixels in the mask will be repainted, while + black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single + channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, + so the expected shape would be `(B, H, W, 1)`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + prior_num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + prior_outputs = self.prior_pipe( + prompt=prompt, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + num_inference_steps=prior_num_inference_steps, + generator=generator, + latents=latents, + guidance_scale=prior_guidance_scale, + output_type="pt", + return_dict=False, + ) + image_embeds = prior_outputs[0] + negative_image_embeds = prior_outputs[1] + + prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt + image = [image] if isinstance(prompt, PIL.Image.Image) else image + mask_image = [mask_image] if isinstance(mask_image, PIL.Image.Image) else mask_image + + if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0: + prompt = (image_embeds.shape[0] // len(prompt)) * prompt + + if ( + isinstance(image, (list, tuple)) + and len(image) < image_embeds.shape[0] + and image_embeds.shape[0] % len(image) == 0 + ): + image = (image_embeds.shape[0] // len(image)) * image + + if ( + isinstance(mask_image, (list, tuple)) + and len(mask_image) < image_embeds.shape[0] + and image_embeds.shape[0] % len(mask_image) == 0 + ): + mask_image = (image_embeds.shape[0] // len(mask_image)) * mask_image + + outputs = self.decoder_pipe( + prompt=prompt, + image=image, + mask_image=mask_image, + image_embeds=image_embeds, + negative_image_embeds=negative_image_embeds, + width=width, + height=height, + num_inference_steps=num_inference_steps, + generator=generator, + guidance_scale=guidance_scale, + output_type=output_type, + callback=callback, + callback_steps=callback_steps, + return_dict=return_dict, + ) + + self.maybe_free_model_hooks() + + return outputs diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..f33317447b49a264b104bbbf54c2e69d49ee9759 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py @@ -0,0 +1,501 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable + +import PIL.Image +import torch +from transformers import ( + XLMRobertaTokenizer, +) + +from ...image_processor import VaeImageProcessor +from ...models import UNet2DConditionModel, VQModel +from ...schedulers import DDIMScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .text_encoder import MultilingualCLIP + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyImg2ImgPipeline, KandinskyPriorPipeline + >>> from diffusers.utils import load_image + >>> import torch + + >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + + >>> prompt = "A red cartoon frog, 4k" + >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False) + + >>> pipe = KandinskyImg2ImgPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/frog.png" + ... ) + + >>> image = pipe( + ... prompt, + ... image=init_image, + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=100, + ... strength=0.2, + ... ).images + + >>> image[0].save("red_frog.png") + ``` +""" + + +def get_new_h_w(h, w, scale_factor=8): + new_h = h // scale_factor**2 + if h % scale_factor**2 != 0: + new_h += 1 + new_w = w // scale_factor**2 + if w % scale_factor**2 != 0: + new_w += 1 + return new_h * scale_factor, new_w * scale_factor + + +class KandinskyImg2ImgPipeline(DiffusionPipeline): + """ + Pipeline for image-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`MultilingualCLIP`]): + Frozen text-encoder. + tokenizer ([`XLMRobertaTokenizer`]): + Tokenizer of class + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ image encoder and decoder + """ + + model_cpu_offload_seq = "text_encoder->unet->movq" + + def __init__( + self, + text_encoder: MultilingualCLIP, + movq: VQModel, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + movq=movq, + ) + self.movq_scale_factor = ( + 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8 + ) + movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4 + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.movq_scale_factor, + vae_latent_channels=movq_latent_channels, + resample="bicubic", + reducing_gap=1, + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, latents, latent_timestep, shape, dtype, device, generator, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + + shape = latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + latents = self.add_noise(latents, noise, latent_timestep) + return latents + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_input_ids = text_input_ids.to(device) + text_mask = text_inputs.attention_mask.to(device) + + prompt_embeds, text_encoder_hidden_states = self.text_encoder( + input_ids=text_input_ids, attention_mask=text_mask + ) + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=77, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + uncond_text_input_ids = uncond_input.input_ids.to(device) + uncond_text_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder( + input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask + ) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + # add_noise method to overwrite the one in schedule because it use a different beta schedule for adding noise vs sampling + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + betas = torch.linspace(0.0001, 0.02, 1000, dtype=torch.float32) + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_cumprod = alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + + return noisy_samples + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str], + image: torch.Tensor | PIL.Image.Image | list[torch.Tensor] | list[PIL.Image.Image], + image_embeds: torch.Tensor, + negative_image_embeds: torch.Tensor, + negative_prompt: str | list[str] | None = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + strength: float = 0.3, + guidance_scale: float = 7.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + output_type: str | None = "pil", + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + image (`torch.Tensor`, `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + image_embeds (`torch.Tensor` or `list[torch.Tensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + negative_image_embeds (`torch.Tensor` or `list[torch.Tensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + strength (`float`, *optional*, defaults to 0.3): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + # 1. Define call parameters + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + + # 2. get text and image embeddings + prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + + if do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( + dtype=prompt_embeds.dtype, device=device + ) + + # 3. pre-processing initial image + if not isinstance(image, list): + image = [image] + if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image): + raise ValueError( + f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor" + ) + + image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0) + image = image.to(dtype=prompt_embeds.dtype, device=device) + + latents = self.movq.encode(image)["latents"] + latents = latents.repeat_interleave(num_images_per_prompt, dim=0) + + # 4. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps_tensor, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + # the formular to calculate timestep for add_noise is taken from the original kandinsky repo + latent_timestep = int(self.scheduler.config.num_train_timesteps * strength) - 2 + + latent_timestep = torch.tensor([latent_timestep] * batch_size, dtype=timesteps_tensor.dtype, device=device) + + num_channels_latents = self.unet.config.in_channels + + height, width = get_new_h_w(height, width, self.movq_scale_factor) + + # 5. Create initial latent + latents = self.prepare_latents( + latents, + latent_timestep, + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + self.scheduler, + ) + + # 6. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + ).prev_sample + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # 7. post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + self.maybe_free_model_hooks() + + if output_type not in ["pt", "np", "pil"]: + raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") + + image = self.image_processor.postprocess(image, output_type) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..4dba85446db5c8bed9633278fb186874624e7a2a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py @@ -0,0 +1,643 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from typing import Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from packaging import version +from PIL import Image +from transformers import ( + XLMRobertaTokenizer, +) + +from ... import __version__ +from ...models import UNet2DConditionModel, VQModel +from ...schedulers import DDIMScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .text_encoder import MultilingualCLIP + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyInpaintPipeline, KandinskyPriorPipeline + >>> from diffusers.utils import load_image + >>> import torch + >>> import numpy as np + + >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + + >>> prompt = "a hat" + >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False) + + >>> pipe = KandinskyInpaintPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ) + + >>> mask = np.zeros((768, 768), dtype=np.float32) + >>> mask[:250, 250:-250] = 1 + + >>> out = pipe( + ... prompt, + ... image=init_image, + ... mask_image=mask, + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=50, + ... ) + + >>> image = out.images[0] + >>> image.save("cat_with_hat.png") + ``` +""" + + +def get_new_h_w(h, w, scale_factor=8): + new_h = h // scale_factor**2 + if h % scale_factor**2 != 0: + new_h += 1 + new_w = w // scale_factor**2 + if w % scale_factor**2 != 0: + new_w += 1 + return new_h * scale_factor, new_w * scale_factor + + +def prepare_mask(masks): + prepared_masks = [] + for mask in masks: + old_mask = deepcopy(mask) + for i in range(mask.shape[1]): + for j in range(mask.shape[2]): + if old_mask[0][i][j] == 1: + continue + if i != 0: + mask[:, i - 1, j] = 0 + if j != 0: + mask[:, i, j - 1] = 0 + if i != 0 and j != 0: + mask[:, i - 1, j - 1] = 0 + if i != mask.shape[1] - 1: + mask[:, i + 1, j] = 0 + if j != mask.shape[2] - 1: + mask[:, i, j + 1] = 0 + if i != mask.shape[1] - 1 and j != mask.shape[2] - 1: + mask[:, i + 1, j + 1] = 0 + prepared_masks.append(mask) + return torch.stack(prepared_masks, dim=0) + + +def prepare_mask_and_masked_image(image, mask, height, width): + r""" + Prepares a pair (mask, image) to be consumed by the Kandinsky inpaint pipeline. This means that those inputs will + be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for + the ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (np.array | PIL.Image | torch.Tensor): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + if mask is None: + raise ValueError("`mask_image` input cannot be undefined.") + + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + # resize all images w.r.t passed height an width + image = [i.resize((width, height), resample=Image.BICUBIC, reducing_gap=1) for i in image] + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + mask = 1 - mask + + return mask, image + + +class KandinskyInpaintPipeline(DiffusionPipeline): + """ + Pipeline for text-guided image inpainting using Kandinsky2.1 + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`MultilingualCLIP`]): + Frozen text-encoder. + tokenizer ([`XLMRobertaTokenizer`]): + Tokenizer of class + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ image encoder and decoder + """ + + model_cpu_offload_seq = "text_encoder->unet->movq" + + def __init__( + self, + text_encoder: MultilingualCLIP, + movq: VQModel, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + movq=movq, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) + self._warn_has_been_called = False + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_input_ids = text_input_ids.to(device) + text_mask = text_inputs.attention_mask.to(device) + + prompt_embeds, text_encoder_hidden_states = self.text_encoder( + input_ids=text_input_ids, attention_mask=text_mask + ) + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=77, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + uncond_text_input_ids = uncond_input.input_ids.to(device) + uncond_text_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder( + input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask + ) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str], + image: torch.Tensor | PIL.Image.Image, + mask_image: torch.Tensor | PIL.Image.Image | np.ndarray, + image_embeds: torch.Tensor, + negative_image_embeds: torch.Tensor, + negative_prompt: str | list[str] | None = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + image (`torch.Tensor`, `PIL.Image.Image` or `np.ndarray`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + mask_image (`PIL.Image.Image`,`torch.Tensor` or `np.ndarray`): + `Image`, or a tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. You can pass a pytorch tensor as mask only if the + image you passed is a pytorch tensor, and it should contain one color channel (L) instead of 3, so the + expected shape would be either `(B, 1, H, W,)`, `(B, H, W)`, `(1, H, W)` or `(H, W)` If image is an PIL + image or numpy array, mask should also be a either PIL image or numpy array. If it is a PIL image, it + will be converted to a single channel (luminance) before use. If it is a nummpy array, the expected + shape is `(H, W)`. + image_embeds (`torch.Tensor` or `list[torch.Tensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + negative_image_embeds (`torch.Tensor` or `list[torch.Tensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + if not self._warn_has_been_called and version.parse(version.parse(__version__).base_version) < version.parse( + "0.23.0.dev0" + ): + logger.warning( + "Please note that the expected format of `mask_image` has recently been changed. " + "Before diffusers == 0.19.0, Kandinsky Inpainting pipelines repainted black pixels and preserved black pixels. " + "As of diffusers==0.19.0 this behavior has been inverted. Now white pixels are repainted and black pixels are preserved. " + "This way, Kandinsky's masking behavior is aligned with Stable Diffusion. " + "THIS means that you HAVE to invert the input mask to have the same behavior as before as explained in https://github.com/huggingface/diffusers/pull/4207. " + "This warning will be suppressed after the first inference call and will be removed in diffusers>0.23.0" + ) + self._warn_has_been_called = True + + # Define call parameters + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + + if do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( + dtype=prompt_embeds.dtype, device=device + ) + + # preprocess image and mask + mask_image, image = prepare_mask_and_masked_image(image, mask_image, height, width) + + image = image.to(dtype=prompt_embeds.dtype, device=device) + image = self.movq.encode(image)["latents"] + + mask_image = mask_image.to(dtype=prompt_embeds.dtype, device=device) + + image_shape = tuple(image.shape[-2:]) + mask_image = F.interpolate( + mask_image, + image_shape, + mode="nearest", + ) + mask_image = prepare_mask(mask_image) + masked_image = image * mask_image + + mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0) + masked_image = masked_image.repeat_interleave(num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + mask_image = mask_image.repeat(2, 1, 1, 1) + masked_image = masked_image.repeat(2, 1, 1, 1) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps_tensor = self.scheduler.timesteps + + num_channels_latents = self.movq.config.latent_channels + + # get h, w for latents + sample_height, sample_width = get_new_h_w(height, width, self.movq_scale_factor) + + # create initial latent + latents = self.prepare_latents( + (batch_size, num_channels_latents, sample_height, sample_width), + text_encoder_hidden_states.dtype, + device, + generator, + latents, + self.scheduler, + ) + + # Check that sizes of mask, masked image and latents match with expected + num_channels_mask = mask_image.shape[1] + num_channels_masked_image = masked_image.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = torch.cat([latent_model_input, masked_image, mask_image], dim=1) + + added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + ).prev_sample + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + self.maybe_free_model_hooks() + + if output_type not in ["pt", "np", "pil"]: + raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py new file mode 100644 index 0000000000000000000000000000000000000000..523fd010eb7f3f4e5535c1da2e4206b700d1e498 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py @@ -0,0 +1,558 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...models import PriorTransformer +from ...schedulers import UnCLIPScheduler +from ...utils import ( + BaseOutput, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyPipeline, KandinskyPriorPipeline + >>> import torch + + >>> pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior") + >>> pipe_prior.to("cuda") + + >>> prompt = "red cat, 4k photo" + >>> out = pipe_prior(prompt) + >>> image_emb = out.image_embeds + >>> negative_image_emb = out.negative_image_embeds + + >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1") + >>> pipe.to("cuda") + + >>> image = pipe( + ... prompt, + ... image_embeds=image_emb, + ... negative_image_embeds=negative_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=100, + ... ).images + + >>> image[0].save("cat.png") + ``` +""" + +EXAMPLE_INTERPOLATE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyPriorPipeline, KandinskyPipeline + >>> from diffusers.utils import load_image + >>> import PIL + + >>> import torch + >>> from torchvision import transforms + + >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + + >>> img1 = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ) + + >>> img2 = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/starry_night.jpeg" + ... ) + + >>> images_texts = ["a cat", img1, img2] + >>> weights = [0.3, 0.3, 0.4] + >>> image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights) + + >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16) + >>> pipe.to("cuda") + + >>> image = pipe( + ... "", + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=150, + ... ).images[0] + + >>> image.save("starry_cat.png") + ``` +""" + + +@dataclass +class KandinskyPriorPipelineOutput(BaseOutput): + """ + Output class for KandinskyPriorPipeline. + + Args: + image_embeds (`torch.Tensor`) + clip image embeddings for text prompt + negative_image_embeds (`list[PIL.Image.Image]` or `np.ndarray`) + clip image embeddings for unconditional tokens + """ + + image_embeds: torch.Tensor | np.ndarray + negative_image_embeds: torch.Tensor | np.ndarray + + +class KandinskyPriorPipeline(DiffusionPipeline): + """ + Pipeline for generating image prior for Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen image-encoder. + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + scheduler ([`UnCLIPScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + """ + + _exclude_from_cpu_offload = ["prior"] + model_cpu_offload_seq = "text_encoder->prior" + + def __init__( + self, + prior: PriorTransformer, + image_encoder: CLIPVisionModelWithProjection, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + scheduler: UnCLIPScheduler, + image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + prior=prior, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + image_encoder=image_encoder, + image_processor=image_processor, + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING) + def interpolate( + self, + images_and_prompts: list[str | PIL.Image.Image | torch.Tensor], + weights: list[float], + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + negative_prior_prompt: str | None = None, + negative_prompt: str = "", + guidance_scale: float = 4.0, + device=None, + ): + """ + Function invoked when using the prior pipeline for interpolation. + + Args: + images_and_prompts (`list[str | PIL.Image.Image | torch.Tensor]`): + list of prompts and images to guide the image generation. + weights: (`list[float]`): + list of weights for each condition in `images_and_prompts` + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + negative_prior_prompt (`str`, *optional*): + The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than `1`). + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than `1`). + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + + Examples: + + Returns: + [`KandinskyPriorPipelineOutput`] or `tuple` + """ + + device = device or self.device + + if len(images_and_prompts) != len(weights): + raise ValueError( + f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length" + ) + + image_embeddings = [] + for cond, weight in zip(images_and_prompts, weights): + if isinstance(cond, str): + image_emb = self( + cond, + num_inference_steps=num_inference_steps, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + negative_prompt=negative_prior_prompt, + guidance_scale=guidance_scale, + ).image_embeds + + elif isinstance(cond, (PIL.Image.Image, torch.Tensor)): + if isinstance(cond, PIL.Image.Image): + cond = ( + self.image_processor(cond, return_tensors="pt") + .pixel_values[0] + .unsqueeze(0) + .to(dtype=self.image_encoder.dtype, device=device) + ) + + image_emb = self.image_encoder(cond)["image_embeds"] + + else: + raise ValueError( + f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}" + ) + + image_embeddings.append(image_emb * weight) + + image_emb = torch.cat(image_embeddings).sum(dim=0, keepdim=True) + + out_zero = self( + negative_prompt, + num_inference_steps=num_inference_steps, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + negative_prompt=negative_prior_prompt, + guidance_scale=guidance_scale, + ) + zero_image_emb = out_zero.negative_image_embeds if negative_prompt == "" else out_zero.image_embeds + + return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=zero_image_emb) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def get_zero_embed(self, batch_size=1, device=None): + device = device or self.device + zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to( + device=device, dtype=self.image_encoder.dtype + ) + zero_image_emb = self.image_encoder(zero_img)["image_embeds"] + zero_image_emb = zero_image_emb.repeat(batch_size, 1) + return zero_image_emb + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + prompt_embeds = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds + uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + guidance_scale: float = 4.0, + output_type: str | None = "pt", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + output_type (`str`, *optional*, defaults to `"pt"`): + The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"` + (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`KandinskyPriorPipelineOutput`] or `tuple` + """ + + if isinstance(prompt, str): + prompt = [prompt] + elif not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif not isinstance(negative_prompt, list) and negative_prompt is not None: + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + # if the negative prompt is defined we double the batch size to + # directly retrieve the negative prompt embedding + if negative_prompt is not None: + prompt = prompt + negative_prompt + negative_prompt = 2 * negative_prompt + + device = self._execution_device + + batch_size = len(prompt) + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # prior + self.scheduler.set_timesteps(num_inference_steps, device=device) + prior_timesteps_tensor = self.scheduler.timesteps + + embedding_dim = self.prior.config.embedding_dim + + latents = self.prepare_latents( + (batch_size, embedding_dim), + prompt_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + + for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=prompt_embeds, + encoder_hidden_states=text_encoder_hidden_states, + attention_mask=text_mask, + ).predicted_image_embedding + + if do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + if i + 1 == prior_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = prior_timesteps_tensor[i + 1] + + latents = self.scheduler.step( + predicted_image_embedding, + timestep=t, + sample=latents, + generator=generator, + prev_timestep=prev_timestep, + ).prev_sample + + if XLA_AVAILABLE: + xm.mark_step() + + latents = self.prior.post_process_latents(latents) + + image_embeddings = latents + + # if negative prompt has been defined, we retrieve split the image embedding into two + if negative_prompt is None: + zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device) + + self.maybe_free_model_hooks() + else: + image_embeddings, zero_embeds = image_embeddings.chunk(2) + + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.prior_hook.offload() + + if output_type not in ["pt", "np"]: + raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") + + if output_type == "np": + image_embeddings = image_embeddings.cpu().numpy() + zero_embeds = zero_embeds.cpu().numpy() + + if not return_dict: + return (image_embeddings, zero_embeds) + + return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky/text_encoder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky/text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..58cc9ac4d3ed0fd7f17c074d835d96930b195d3b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky/text_encoder.py @@ -0,0 +1,29 @@ +import torch +from transformers import PreTrainedModel, XLMRobertaConfig, XLMRobertaModel + + +class MCLIPConfig(XLMRobertaConfig): + model_type = "M-CLIP" + + def __init__(self, transformerDimSize=1024, imageDimSize=768, **kwargs): + self.transformerDimensions = transformerDimSize + self.numDims = imageDimSize + super().__init__(**kwargs) + + +class MultilingualCLIP(PreTrainedModel): + config_class = MCLIPConfig + + def __init__(self, config, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.transformer = XLMRobertaModel(config) + self.LinearTransformation = torch.nn.Linear( + in_features=config.transformerDimensions, out_features=config.numDims + ) + if hasattr(self, "post_init"): + self.post_init() + + def forward(self, input_ids, attention_mask): + embs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)[0] + embs2 = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum(dim=1)[:, None] + return self.LinearTransformation(embs2), embs diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..67e97f161173ac8981dadf757fd8d6438307c973 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/__init__.py @@ -0,0 +1,70 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_kandinsky2_2"] = ["KandinskyV22Pipeline"] + _import_structure["pipeline_kandinsky2_2_combined"] = [ + "KandinskyV22CombinedPipeline", + "KandinskyV22Img2ImgCombinedPipeline", + "KandinskyV22InpaintCombinedPipeline", + ] + _import_structure["pipeline_kandinsky2_2_controlnet"] = ["KandinskyV22ControlnetPipeline"] + _import_structure["pipeline_kandinsky2_2_controlnet_img2img"] = ["KandinskyV22ControlnetImg2ImgPipeline"] + _import_structure["pipeline_kandinsky2_2_img2img"] = ["KandinskyV22Img2ImgPipeline"] + _import_structure["pipeline_kandinsky2_2_inpainting"] = ["KandinskyV22InpaintPipeline"] + _import_structure["pipeline_kandinsky2_2_prior"] = ["KandinskyV22PriorPipeline"] + _import_structure["pipeline_kandinsky2_2_prior_emb2emb"] = ["KandinskyV22PriorEmb2EmbPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_kandinsky2_2 import KandinskyV22Pipeline + from .pipeline_kandinsky2_2_combined import ( + KandinskyV22CombinedPipeline, + KandinskyV22Img2ImgCombinedPipeline, + KandinskyV22InpaintCombinedPipeline, + ) + from .pipeline_kandinsky2_2_controlnet import KandinskyV22ControlnetPipeline + from .pipeline_kandinsky2_2_controlnet_img2img import KandinskyV22ControlnetImg2ImgPipeline + from .pipeline_kandinsky2_2_img2img import KandinskyV22Img2ImgPipeline + from .pipeline_kandinsky2_2_inpainting import KandinskyV22InpaintPipeline + from .pipeline_kandinsky2_2_prior import KandinskyV22PriorPipeline + from .pipeline_kandinsky2_2_prior_emb2emb import KandinskyV22PriorEmb2EmbPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py new file mode 100644 index 0000000000000000000000000000000000000000..5129b3f548e89c8b8a0b146df6414b37b7b5b138 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py @@ -0,0 +1,331 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable + +import torch + +from ...models import UNet2DConditionModel, VQModel +from ...schedulers import DDPMScheduler +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorPipeline + >>> import torch + + >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior") + >>> pipe_prior.to("cuda") + >>> prompt = "red cat, 4k photo" + >>> out = pipe_prior(prompt) + >>> image_emb = out.image_embeds + >>> zero_image_emb = out.negative_image_embeds + >>> pipe = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder") + >>> pipe.to("cuda") + >>> image = pipe( + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=50, + ... ).images + >>> image[0].save("cat.png") + ``` +""" + + +def downscale_height_and_width(height, width, scale_factor=8): + new_height = height // scale_factor**2 + if height % scale_factor**2 != 0: + new_height += 1 + new_width = width // scale_factor**2 + if width % scale_factor**2 != 0: + new_width += 1 + return new_height * scale_factor, new_width * scale_factor + + +class KandinskyV22Pipeline(DiffusionPipeline): + """ + Pipeline for text-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + scheduler (`DDIMScheduler` | `DDPMScheduler`): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + """ + + model_cpu_offload_seq = "unet->movq" + _callback_tensor_inputs = ["latents", "image_embeds", "negative_image_embeds"] + + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + movq: VQModel, + ): + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + movq=movq, + ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image_embeds: torch.Tensor | list[torch.Tensor], + negative_image_embeds: torch.Tensor | list[torch.Tensor], + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image_embeds (`torch.Tensor` or `list[torch.Tensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + negative_image_embeds (`torch.Tensor` or `list[torch.Tensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + device = self._execution_device + + self._guidance_scale = guidance_scale + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + batch_size = image_embeds.shape[0] * num_images_per_prompt + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + + if self.do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( + dtype=self.unet.dtype, device=device + ) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + num_channels_latents = self.unet.config.in_channels + + height, width = downscale_height_and_width(height, width, self.movq_scale_factor) + + # create initial latent + latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + image_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + + self._num_timesteps = len(timesteps) + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + added_cond_kwargs = {"image_embeds": image_embeds} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=None, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + )[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + image_embeds = callback_outputs.pop("image_embeds", image_embeds) + negative_image_embeds = callback_outputs.pop("negative_image_embeds", negative_image_embeds) + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type not in ["pt", "np", "pil", "latent"]: + raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") + + if not output_type == "latent": + # post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + else: + image = latents + + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py new file mode 100644 index 0000000000000000000000000000000000000000..26e163a7014233e179db62c359b4f1e4ec69c991 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py @@ -0,0 +1,854 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable + +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...models import PriorTransformer, UNet2DConditionModel, VQModel +from ...schedulers import DDPMScheduler, UnCLIPScheduler +from ...utils import deprecate, logging, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline +from .pipeline_kandinsky2_2 import KandinskyV22Pipeline +from .pipeline_kandinsky2_2_img2img import KandinskyV22Img2ImgPipeline +from .pipeline_kandinsky2_2_inpainting import KandinskyV22InpaintPipeline +from .pipeline_kandinsky2_2_prior import KandinskyV22PriorPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +TEXT2IMAGE_EXAMPLE_DOC_STRING = """ + Examples: + ```py + from diffusers import AutoPipelineForText2Image + import torch + + pipe = AutoPipelineForText2Image.from_pretrained( + "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16 + ) + pipe.enable_model_cpu_offload() + + prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k" + + image = pipe(prompt=prompt, num_inference_steps=25).images[0] + ``` +""" + +IMAGE2IMAGE_EXAMPLE_DOC_STRING = """ + Examples: + ```py + from diffusers import AutoPipelineForImage2Image + import torch + import requests + from io import BytesIO + from PIL import Image + import os + + pipe = AutoPipelineForImage2Image.from_pretrained( + "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16 + ) + pipe.enable_model_cpu_offload() + + prompt = "A fantasy landscape, Cinematic lighting" + negative_prompt = "low quality, bad quality" + + url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + + response = requests.get(url) + image = Image.open(BytesIO(response.content)).convert("RGB") + image.thumbnail((768, 768)) + + image = pipe(prompt=prompt, image=original_image, num_inference_steps=25).images[0] + ``` +""" + +INPAINT_EXAMPLE_DOC_STRING = """ + Examples: + ```py + from diffusers import AutoPipelineForInpainting + from diffusers.utils import load_image + import torch + import numpy as np + + pipe = AutoPipelineForInpainting.from_pretrained( + "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16 + ) + pipe.enable_model_cpu_offload() + + prompt = "A fantasy landscape, Cinematic lighting" + negative_prompt = "low quality, bad quality" + + original_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png" + ) + + mask = np.zeros((768, 768), dtype=np.float32) + # Let's mask out an area above the cat's head + mask[:250, 250:-250] = 1 + + image = pipe(prompt=prompt, image=original_image, mask_image=mask, num_inference_steps=25).images[0] + ``` +""" + + +class KandinskyV22CombinedPipeline(DiffusionPipeline): + """ + Combined Pipeline for text-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + scheduler (`DDIMScheduler` | `DDPMScheduler`): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + prior_prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + prior_image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen image-encoder. + prior_text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + prior_tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + prior_scheduler ([`UnCLIPScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + prior_image_processor ([`CLIPImageProcessor`]): + A image_processor to be used to preprocess image from clip. + """ + + model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->unet->movq" + _load_connected_pipes = True + _exclude_from_cpu_offload = ["prior_prior"] + + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + movq: VQModel, + prior_prior: PriorTransformer, + prior_image_encoder: CLIPVisionModelWithProjection, + prior_text_encoder: CLIPTextModelWithProjection, + prior_tokenizer: CLIPTokenizer, + prior_scheduler: UnCLIPScheduler, + prior_image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + movq=movq, + prior_prior=prior_prior, + prior_image_encoder=prior_image_encoder, + prior_text_encoder=prior_text_encoder, + prior_tokenizer=prior_tokenizer, + prior_scheduler=prior_scheduler, + prior_image_processor=prior_image_processor, + ) + self.prior_pipe = KandinskyV22PriorPipeline( + prior=prior_prior, + image_encoder=prior_image_encoder, + text_encoder=prior_text_encoder, + tokenizer=prior_tokenizer, + scheduler=prior_scheduler, + image_processor=prior_image_processor, + ) + self.decoder_pipe = KandinskyV22Pipeline( + unet=unet, + scheduler=scheduler, + movq=movq, + ) + + def enable_xformers_memory_efficient_attention(self, attention_op: Callable | None = None): + self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) + + def enable_sequential_cpu_offload(self, gpu_id: int | None = None, device: torch.device | str = None): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + + def progress_bar(self, iterable=None, total=None): + self.prior_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.enable_model_cpu_offload() + + def set_progress_bar_config(self, **kwargs): + self.prior_pipe.set_progress_bar_config(**kwargs) + self.decoder_pipe.set_progress_bar_config(**kwargs) + + @torch.no_grad() + @replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + height: int = 512, + width: int = 512, + prior_guidance_scale: float = 4.0, + prior_num_inference_steps: int = 25, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + return_dict: bool = True, + prior_callback_on_step_end: Callable[[int, int], None] | None = None, + prior_callback_on_step_end_tensor_inputs: list[str] = ["latents"], + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + prior_num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + prior_callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference of the prior pipeline. + The function is called with the following arguments: `prior_callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. + prior_callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the + list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in + the `._callback_tensor_inputs` attribute of your prior pipeline class. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference of the decoder pipeline. + The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, + step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors + as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + prior_outputs = self.prior_pipe( + prompt=prompt, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + num_inference_steps=prior_num_inference_steps, + generator=generator, + latents=latents, + guidance_scale=prior_guidance_scale, + output_type="pt", + return_dict=False, + callback_on_step_end=prior_callback_on_step_end, + callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs, + ) + image_embeds = prior_outputs[0] + negative_image_embeds = prior_outputs[1] + + prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt + + if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0: + prompt = (image_embeds.shape[0] // len(prompt)) * prompt + + outputs = self.decoder_pipe( + image_embeds=image_embeds, + negative_image_embeds=negative_image_embeds, + width=width, + height=height, + num_inference_steps=num_inference_steps, + generator=generator, + guidance_scale=guidance_scale, + output_type=output_type, + callback=callback, + callback_steps=callback_steps, + return_dict=return_dict, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + self.maybe_free_model_hooks() + + return outputs + + +class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline): + """ + Combined Pipeline for image-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + scheduler (`DDIMScheduler` | `DDPMScheduler`): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + prior_prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + prior_image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen image-encoder. + prior_text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + prior_tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + prior_scheduler ([`UnCLIPScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + prior_image_processor ([`CLIPImageProcessor`]): + A image_processor to be used to preprocess image from clip. + """ + + model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->unet->movq" + _load_connected_pipes = True + _exclude_from_cpu_offload = ["prior_prior"] + + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + movq: VQModel, + prior_prior: PriorTransformer, + prior_image_encoder: CLIPVisionModelWithProjection, + prior_text_encoder: CLIPTextModelWithProjection, + prior_tokenizer: CLIPTokenizer, + prior_scheduler: UnCLIPScheduler, + prior_image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + movq=movq, + prior_prior=prior_prior, + prior_image_encoder=prior_image_encoder, + prior_text_encoder=prior_text_encoder, + prior_tokenizer=prior_tokenizer, + prior_scheduler=prior_scheduler, + prior_image_processor=prior_image_processor, + ) + self.prior_pipe = KandinskyV22PriorPipeline( + prior=prior_prior, + image_encoder=prior_image_encoder, + text_encoder=prior_text_encoder, + tokenizer=prior_tokenizer, + scheduler=prior_scheduler, + image_processor=prior_image_processor, + ) + self.decoder_pipe = KandinskyV22Img2ImgPipeline( + unet=unet, + scheduler=scheduler, + movq=movq, + ) + + def enable_xformers_memory_efficient_attention(self, attention_op: Callable | None = None): + self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) + + def enable_model_cpu_offload(self, gpu_id: int | None = None, device: torch.device | str = None): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) + self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) + + def enable_sequential_cpu_offload(self, gpu_id: int | None = None, device: torch.device | str = None): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + + def progress_bar(self, iterable=None, total=None): + self.prior_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.enable_model_cpu_offload() + + def set_progress_bar_config(self, **kwargs): + self.prior_pipe.set_progress_bar_config(**kwargs) + self.decoder_pipe.set_progress_bar_config(**kwargs) + + @torch.no_grad() + @replace_example_docstring(IMAGE2IMAGE_EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str], + image: torch.Tensor | PIL.Image.Image | list[torch.Tensor] | list[PIL.Image.Image], + negative_prompt: str | list[str] | None = None, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + strength: float = 0.3, + num_images_per_prompt: int = 1, + height: int = 512, + width: int = 512, + prior_guidance_scale: float = 4.0, + prior_num_inference_steps: int = 25, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + return_dict: bool = True, + prior_callback_on_step_end: Callable[[int, int], None] | None = None, + prior_callback_on_step_end_tensor_inputs: list[str] = ["latents"], + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded + again. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + strength (`float`, *optional*, defaults to 0.3): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + prior_num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + prior_outputs = self.prior_pipe( + prompt=prompt, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + num_inference_steps=prior_num_inference_steps, + generator=generator, + latents=latents, + guidance_scale=prior_guidance_scale, + output_type="pt", + return_dict=False, + callback_on_step_end=prior_callback_on_step_end, + callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs, + ) + image_embeds = prior_outputs[0] + negative_image_embeds = prior_outputs[1] + + prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt + image = [image] if isinstance(image, PIL.Image.Image) else image + + if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0: + prompt = (image_embeds.shape[0] // len(prompt)) * prompt + + if ( + isinstance(image, (list, tuple)) + and len(image) < image_embeds.shape[0] + and image_embeds.shape[0] % len(image) == 0 + ): + image = (image_embeds.shape[0] // len(image)) * image + + outputs = self.decoder_pipe( + image=image, + image_embeds=image_embeds, + negative_image_embeds=negative_image_embeds, + width=width, + height=height, + strength=strength, + num_inference_steps=num_inference_steps, + generator=generator, + guidance_scale=guidance_scale, + output_type=output_type, + callback=callback, + callback_steps=callback_steps, + return_dict=return_dict, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self.maybe_free_model_hooks() + return outputs + + +class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline): + """ + Combined Pipeline for inpainting generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + scheduler (`DDIMScheduler` | `DDPMScheduler`): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + prior_prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + prior_image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen image-encoder. + prior_text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + prior_tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + prior_scheduler ([`UnCLIPScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + prior_image_processor ([`CLIPImageProcessor`]): + A image_processor to be used to preprocess image from clip. + """ + + model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->unet->movq" + _load_connected_pipes = True + _exclude_from_cpu_offload = ["prior_prior"] + + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + movq: VQModel, + prior_prior: PriorTransformer, + prior_image_encoder: CLIPVisionModelWithProjection, + prior_text_encoder: CLIPTextModelWithProjection, + prior_tokenizer: CLIPTokenizer, + prior_scheduler: UnCLIPScheduler, + prior_image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + movq=movq, + prior_prior=prior_prior, + prior_image_encoder=prior_image_encoder, + prior_text_encoder=prior_text_encoder, + prior_tokenizer=prior_tokenizer, + prior_scheduler=prior_scheduler, + prior_image_processor=prior_image_processor, + ) + self.prior_pipe = KandinskyV22PriorPipeline( + prior=prior_prior, + image_encoder=prior_image_encoder, + text_encoder=prior_text_encoder, + tokenizer=prior_tokenizer, + scheduler=prior_scheduler, + image_processor=prior_image_processor, + ) + self.decoder_pipe = KandinskyV22InpaintPipeline( + unet=unet, + scheduler=scheduler, + movq=movq, + ) + + def enable_xformers_memory_efficient_attention(self, attention_op: Callable | None = None): + self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) + + def enable_sequential_cpu_offload(self, gpu_id: int | None = None, device: torch.device | str = None): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + + def progress_bar(self, iterable=None, total=None): + self.prior_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.enable_model_cpu_offload() + + def set_progress_bar_config(self, **kwargs): + self.prior_pipe.set_progress_bar_config(**kwargs) + self.decoder_pipe.set_progress_bar_config(**kwargs) + + @torch.no_grad() + @replace_example_docstring(INPAINT_EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str], + image: torch.Tensor | PIL.Image.Image | list[torch.Tensor] | list[PIL.Image.Image], + mask_image: torch.Tensor | PIL.Image.Image | list[torch.Tensor] | list[PIL.Image.Image], + negative_prompt: str | list[str] | None = None, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + height: int = 512, + width: int = 512, + prior_guidance_scale: float = 4.0, + prior_num_inference_steps: int = 25, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + prior_callback_on_step_end: Callable[[int, int], None] | None = None, + prior_callback_on_step_end_tensor_inputs: list[str] = ["latents"], + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded + again. + mask_image (`np.array`): + Tensor representing an image batch, to mask `image`. White pixels in the mask will be repainted, while + black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single + channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, + so the expected shape would be `(B, H, W, 1)`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + prior_num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + prior_callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `prior_callback_on_step_end(self: DiffusionPipeline, step: int, timestep: + int, callback_kwargs: Dict)`. + prior_callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the + list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in + the `._callback_tensor_inputs` attribute of your pipeline class. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + prior_kwargs = {} + if kwargs.get("prior_callback", None) is not None: + prior_kwargs["callback"] = kwargs.pop("prior_callback") + deprecate( + "prior_callback", + "1.0.0", + "Passing `prior_callback` as an input argument to `__call__` is deprecated, consider use `prior_callback_on_step_end`", + ) + if kwargs.get("prior_callback_steps", None) is not None: + deprecate( + "prior_callback_steps", + "1.0.0", + "Passing `prior_callback_steps` as an input argument to `__call__` is deprecated, consider use `prior_callback_on_step_end`", + ) + prior_kwargs["callback_steps"] = kwargs.pop("prior_callback_steps") + + prior_outputs = self.prior_pipe( + prompt=prompt, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + num_inference_steps=prior_num_inference_steps, + generator=generator, + latents=latents, + guidance_scale=prior_guidance_scale, + output_type="pt", + return_dict=False, + callback_on_step_end=prior_callback_on_step_end, + callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs, + **prior_kwargs, + ) + image_embeds = prior_outputs[0] + negative_image_embeds = prior_outputs[1] + + prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt + image = [image] if isinstance(image, PIL.Image.Image) else image + mask_image = [mask_image] if isinstance(mask_image, PIL.Image.Image) else mask_image + + if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0: + prompt = (image_embeds.shape[0] // len(prompt)) * prompt + + if ( + isinstance(image, (list, tuple)) + and len(image) < image_embeds.shape[0] + and image_embeds.shape[0] % len(image) == 0 + ): + image = (image_embeds.shape[0] // len(image)) * image + + if ( + isinstance(mask_image, (list, tuple)) + and len(mask_image) < image_embeds.shape[0] + and image_embeds.shape[0] % len(mask_image) == 0 + ): + mask_image = (image_embeds.shape[0] // len(mask_image)) * mask_image + + outputs = self.decoder_pipe( + image=image, + mask_image=mask_image, + image_embeds=image_embeds, + negative_image_embeds=negative_image_embeds, + width=width, + height=height, + num_inference_steps=num_inference_steps, + generator=generator, + guidance_scale=guidance_scale, + output_type=output_type, + return_dict=return_dict, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + **kwargs, + ) + self.maybe_free_model_hooks() + + return outputs diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..01001e0c9eba8ca1f439f539b3df2d11c54440ff --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py @@ -0,0 +1,333 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable + +import torch + +from ...models import UNet2DConditionModel, VQModel +from ...schedulers import DDPMScheduler +from ...utils import ( + is_torch_xla_available, + logging, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> import numpy as np + + >>> from diffusers import KandinskyV22PriorPipeline, KandinskyV22ControlnetPipeline + >>> from transformers import pipeline + >>> from diffusers.utils import load_image + + + >>> def make_hint(image, depth_estimator): + ... image = depth_estimator(image)["depth"] + ... image = np.array(image) + ... image = image[:, :, None] + ... image = np.concatenate([image, image, image], axis=2) + ... detected_map = torch.from_numpy(image).float() / 255.0 + ... hint = detected_map.permute(2, 0, 1) + ... return hint + + + >>> depth_estimator = pipeline("depth-estimation") + + >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior = pipe_prior.to("cuda") + + >>> pipe = KandinskyV22ControlnetPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + + >>> img = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ).resize((768, 768)) + + >>> hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda") + + >>> prompt = "A robot, 4k photo" + >>> negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature" + + >>> generator = torch.Generator(device="cuda").manual_seed(43) + + >>> image_emb, zero_image_emb = pipe_prior( + ... prompt=prompt, negative_prompt=negative_prior_prompt, generator=generator + ... ).to_tuple() + + >>> images = pipe( + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... hint=hint, + ... num_inference_steps=50, + ... generator=generator, + ... height=768, + ... width=768, + ... ).images + + >>> images[0].save("robot_cat.png") + ``` +""" + + +# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width +def downscale_height_and_width(height, width, scale_factor=8): + new_height = height // scale_factor**2 + if height % scale_factor**2 != 0: + new_height += 1 + new_width = width // scale_factor**2 + if width % scale_factor**2 != 0: + new_width += 1 + return new_height * scale_factor, new_width * scale_factor + + +class KandinskyV22ControlnetPipeline(DiffusionPipeline): + """ + Pipeline for text-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + """ + + model_cpu_offload_seq = "unet->movq" + + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + movq: VQModel, + ): + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + movq=movq, + ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + image_embeds: torch.Tensor | list[torch.Tensor], + negative_image_embeds: torch.Tensor | list[torch.Tensor], + hint: torch.Tensor, + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + hint (`torch.Tensor`): + The controlnet condition. + image_embeds (`torch.Tensor` or `list[torch.Tensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + negative_image_embeds (`torch.Tensor` or `list[torch.Tensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + device = self._execution_device + + do_classifier_free_guidance = guidance_scale > 1.0 + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + if isinstance(hint, list): + hint = torch.cat(hint, dim=0) + + batch_size = image_embeds.shape[0] * num_images_per_prompt + + if do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + hint = hint.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( + dtype=self.unet.dtype, device=device + ) + hint = torch.cat([hint, hint], dim=0).to(dtype=self.unet.dtype, device=device) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps_tensor = self.scheduler.timesteps + + num_channels_latents = self.movq.config.latent_channels + + height, width = downscale_height_and_width(height, width, self.movq_scale_factor) + + # create initial latent + latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + image_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + added_cond_kwargs = {"image_embeds": image_embeds, "hint": hint} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=None, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + )[0] + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + # Offload all models + self.maybe_free_model_hooks() + + if output_type not in ["pt", "np", "pil"]: + raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..891235558330c9de3c6f30edfc4bc83573fbd473 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py @@ -0,0 +1,371 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable + +import PIL.Image +import torch + +from ...image_processor import VaeImageProcessor +from ...models import UNet2DConditionModel, VQModel +from ...schedulers import DDPMScheduler +from ...utils import ( + is_torch_xla_available, + logging, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> import numpy as np + + >>> from diffusers import KandinskyV22PriorEmb2EmbPipeline, KandinskyV22ControlnetImg2ImgPipeline + >>> from transformers import pipeline + >>> from diffusers.utils import load_image + + + >>> def make_hint(image, depth_estimator): + ... image = depth_estimator(image)["depth"] + ... image = np.array(image) + ... image = image[:, :, None] + ... image = np.concatenate([image, image, image], axis=2) + ... detected_map = torch.from_numpy(image).float() / 255.0 + ... hint = detected_map.permute(2, 0, 1) + ... return hint + + + >>> depth_estimator = pipeline("depth-estimation") + + >>> pipe_prior = KandinskyV22PriorEmb2EmbPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior = pipe_prior.to("cuda") + + >>> pipe = KandinskyV22ControlnetImg2ImgPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> img = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ).resize((768, 768)) + + + >>> hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda") + + >>> prompt = "A robot, 4k photo" + >>> negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature" + + >>> generator = torch.Generator(device="cuda").manual_seed(43) + + >>> img_emb = pipe_prior(prompt=prompt, image=img, strength=0.85, generator=generator) + >>> negative_emb = pipe_prior(prompt=negative_prior_prompt, image=img, strength=1, generator=generator) + + >>> images = pipe( + ... image=img, + ... strength=0.5, + ... image_embeds=img_emb.image_embeds, + ... negative_image_embeds=negative_emb.image_embeds, + ... hint=hint, + ... num_inference_steps=50, + ... generator=generator, + ... height=768, + ... width=768, + ... ).images + + >>> images[0].save("robot_cat.png") + ``` +""" + + +class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline): + """ + Pipeline for image-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + """ + + model_cpu_offload_seq = "unet->movq" + + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + movq: VQModel, + ): + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + movq=movq, + ) + movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8 + movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4 + self.image_processor = VaeImageProcessor( + vae_scale_factor=movq_scale_factor, + vae_latent_channels=movq_latent_channels, + resample="bicubic", + reducing_gap=1, + ) + + # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2_img2img.KandinskyV22Img2ImgPipeline.prepare_latents + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + self.movq.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.movq.encode(image).latent_dist.sample(generator) + + init_latents = self.movq.config.scaling_factor * init_latents + + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + @torch.no_grad() + def __call__( + self, + image_embeds: torch.Tensor | list[torch.Tensor], + image: torch.Tensor | PIL.Image.Image | list[torch.Tensor] | list[PIL.Image.Image], + negative_image_embeds: torch.Tensor | list[torch.Tensor], + hint: torch.Tensor, + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + strength: float = 0.3, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + output_type: str | None = "pil", + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image_embeds (`torch.Tensor` or `list[torch.Tensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded + again. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + hint (`torch.Tensor`): + The controlnet condition. + negative_image_embeds (`torch.Tensor` or `list[torch.Tensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + device = self._execution_device + + do_classifier_free_guidance = guidance_scale > 1.0 + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + if isinstance(hint, list): + hint = torch.cat(hint, dim=0) + + batch_size = image_embeds.shape[0] + + if do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + hint = hint.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( + dtype=self.unet.dtype, device=device + ) + hint = torch.cat([hint, hint], dim=0).to(dtype=self.unet.dtype, device=device) + + if not isinstance(image, list): + image = [image] + if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image): + raise ValueError( + f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor" + ) + + image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0) + image = image.to(dtype=image_embeds.dtype, device=device) + + latents = self.movq.encode(image)["latents"] + latents = latents.repeat_interleave(num_images_per_prompt, dim=0) + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + latents = self.prepare_latents( + latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator + ) + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + added_cond_kwargs = {"image_embeds": image_embeds, "hint": hint} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=None, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + )[0] + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + # Offload all models + self.maybe_free_model_hooks() + + if output_type not in ["pt", "np", "pil"]: + raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") + + image = self.image_processor.postprocess(image, output_type) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..a43c085a89210ca2df618433852661b8bb36db8b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py @@ -0,0 +1,382 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable + +import PIL.Image +import torch + +from ...image_processor import VaeImageProcessor +from ...models import UNet2DConditionModel, VQModel +from ...schedulers import DDPMScheduler +from ...utils import deprecate, is_torch_xla_available, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyV22Img2ImgPipeline, KandinskyV22PriorPipeline + >>> from diffusers.utils import load_image + >>> import torch + + >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + + >>> prompt = "A red cartoon frog, 4k" + >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False) + + >>> pipe = KandinskyV22Img2ImgPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/frog.png" + ... ) + + >>> image = pipe( + ... image=init_image, + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=100, + ... strength=0.2, + ... ).images + + >>> image[0].save("red_frog.png") + ``` +""" + + +class KandinskyV22Img2ImgPipeline(DiffusionPipeline): + """ + Pipeline for image-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + """ + + model_cpu_offload_seq = "unet->movq" + _callback_tensor_inputs = ["latents", "image_embeds", "negative_image_embeds"] + + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + movq: VQModel, + ): + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + movq=movq, + ) + movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8 + movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4 + self.image_processor = VaeImageProcessor( + vae_scale_factor=movq_scale_factor, + vae_latent_channels=movq_latent_channels, + resample="bicubic", + reducing_gap=1, + ) + + # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + self.movq.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.movq.encode(image).latent_dist.sample(generator) + + init_latents = self.movq.config.scaling_factor * init_latents + + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + def __call__( + self, + image_embeds: torch.Tensor | list[torch.Tensor], + image: torch.Tensor | PIL.Image.Image | list[torch.Tensor] | list[PIL.Image.Image], + negative_image_embeds: torch.Tensor | list[torch.Tensor], + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + strength: float = 0.3, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image_embeds (`torch.Tensor` or `list[torch.Tensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded + again. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + negative_image_embeds (`torch.Tensor` or `list[torch.Tensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + device = self._execution_device + + self._guidance_scale = guidance_scale + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + batch_size = image_embeds.shape[0] + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + + if self.do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( + dtype=self.unet.dtype, device=device + ) + + if not isinstance(image, list): + image = [image] + if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image): + raise ValueError( + f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor" + ) + + image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0) + image = image.to(dtype=image_embeds.dtype, device=device) + + latents = self.movq.encode(image)["latents"] + latents = latents.repeat_interleave(num_images_per_prompt, dim=0) + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + latents = self.prepare_latents( + latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator + ) + self._num_timesteps = len(timesteps) + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + added_cond_kwargs = {"image_embeds": image_embeds} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=None, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + )[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + image_embeds = callback_outputs.pop("image_embeds", image_embeds) + negative_image_embeds = callback_outputs.pop("negative_image_embeds", negative_image_embeds) + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + image = self.image_processor.postprocess(image, output_type) + else: + image = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py new file mode 100644 index 0000000000000000000000000000000000000000..31bd88103a06475968897f0722922b651b3e0500 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py @@ -0,0 +1,567 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from typing import Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from packaging import version +from PIL import Image + +from ... import __version__ +from ...models import UNet2DConditionModel, VQModel +from ...schedulers import DDPMScheduler +from ...utils import deprecate, is_torch_xla_available, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyV22InpaintPipeline, KandinskyV22PriorPipeline + >>> from diffusers.utils import load_image + >>> import torch + >>> import numpy as np + + >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + + >>> prompt = "a hat" + >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False) + + >>> pipe = KandinskyV22InpaintPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ) + + >>> mask = np.zeros((768, 768), dtype=np.float32) + >>> mask[:250, 250:-250] = 1 + + >>> out = pipe( + ... image=init_image, + ... mask_image=mask, + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=50, + ... ) + + >>> image = out.images[0] + >>> image.save("cat_with_hat.png") + ``` +""" + + +# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width +def downscale_height_and_width(height, width, scale_factor=8): + new_height = height // scale_factor**2 + if height % scale_factor**2 != 0: + new_height += 1 + new_width = width // scale_factor**2 + if width % scale_factor**2 != 0: + new_width += 1 + return new_height * scale_factor, new_width * scale_factor + + +# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_inpaint.prepare_mask +def prepare_mask(masks): + prepared_masks = [] + for mask in masks: + old_mask = deepcopy(mask) + for i in range(mask.shape[1]): + for j in range(mask.shape[2]): + if old_mask[0][i][j] == 1: + continue + if i != 0: + mask[:, i - 1, j] = 0 + if j != 0: + mask[:, i, j - 1] = 0 + if i != 0 and j != 0: + mask[:, i - 1, j - 1] = 0 + if i != mask.shape[1] - 1: + mask[:, i + 1, j] = 0 + if j != mask.shape[2] - 1: + mask[:, i, j + 1] = 0 + if i != mask.shape[1] - 1 and j != mask.shape[2] - 1: + mask[:, i + 1, j + 1] = 0 + prepared_masks.append(mask) + return torch.stack(prepared_masks, dim=0) + + +# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_inpaint.prepare_mask_and_masked_image +def prepare_mask_and_masked_image(image, mask, height, width): + r""" + Prepares a pair (mask, image) to be consumed by the Kandinsky inpaint pipeline. This means that those inputs will + be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for + the ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (np.array | PIL.Image | torch.Tensor): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + if mask is None: + raise ValueError("`mask_image` input cannot be undefined.") + + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + # resize all images w.r.t passed height an width + image = [i.resize((width, height), resample=Image.BICUBIC, reducing_gap=1) for i in image] + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + mask = 1 - mask + + return mask, image + + +class KandinskyV22InpaintPipeline(DiffusionPipeline): + """ + Pipeline for text-guided image inpainting using Kandinsky2.1 + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + """ + + model_cpu_offload_seq = "unet->movq" + _callback_tensor_inputs = ["latents", "image_embeds", "negative_image_embeds", "masked_image", "mask_image"] + + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: DDPMScheduler, + movq: VQModel, + ): + super().__init__() + + self.register_modules( + unet=unet, + scheduler=scheduler, + movq=movq, + ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) + self._warn_has_been_called = False + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + def __call__( + self, + image_embeds: torch.Tensor | list[torch.Tensor], + image: torch.Tensor | PIL.Image.Image, + mask_image: torch.Tensor | PIL.Image.Image | np.ndarray, + negative_image_embeds: torch.Tensor | list[torch.Tensor], + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image_embeds (`torch.Tensor` or `list[torch.Tensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`np.array`): + Tensor representing an image batch, to mask `image`. White pixels in the mask will be repainted, while + black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single + channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, + so the expected shape would be `(B, H, W, 1)`. + negative_image_embeds (`torch.Tensor` or `list[torch.Tensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + if not self._warn_has_been_called and version.parse(version.parse(__version__).base_version) < version.parse( + "0.23.0.dev0" + ): + logger.warning( + "Please note that the expected format of `mask_image` has recently been changed. " + "Before diffusers == 0.19.0, Kandinsky Inpainting pipelines repainted black pixels and preserved black pixels. " + "As of diffusers==0.19.0 this behavior has been inverted. Now white pixels are repainted and black pixels are preserved. " + "This way, Kandinsky's masking behavior is aligned with Stable Diffusion. " + "THIS means that you HAVE to invert the input mask to have the same behavior as before as explained in https://github.com/huggingface/diffusers/pull/4207. " + "This warning will be suppressed after the first inference call and will be removed in diffusers>0.23.0" + ) + self._warn_has_been_called = True + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + self._guidance_scale = guidance_scale + + device = self._execution_device + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + batch_size = image_embeds.shape[0] * num_images_per_prompt + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + + if self.do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( + dtype=self.unet.dtype, device=device + ) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # preprocess image and mask + mask_image, image = prepare_mask_and_masked_image(image, mask_image, height, width) + + image = image.to(dtype=image_embeds.dtype, device=device) + image = self.movq.encode(image)["latents"] + + mask_image = mask_image.to(dtype=image_embeds.dtype, device=device) + + image_shape = tuple(image.shape[-2:]) + mask_image = F.interpolate( + mask_image, + image_shape, + mode="nearest", + ) + mask_image = prepare_mask(mask_image) + masked_image = image * mask_image + + mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0) + masked_image = masked_image.repeat_interleave(num_images_per_prompt, dim=0) + if self.do_classifier_free_guidance: + mask_image = mask_image.repeat(2, 1, 1, 1) + masked_image = masked_image.repeat(2, 1, 1, 1) + + num_channels_latents = self.movq.config.latent_channels + + height, width = downscale_height_and_width(height, width, self.movq_scale_factor) + + # create initial latent + latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + image_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + noise = torch.clone(latents) + + self._num_timesteps = len(timesteps) + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = torch.cat([latent_model_input, masked_image, mask_image], dim=1) + + added_cond_kwargs = {"image_embeds": image_embeds} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=None, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + )[0] + init_latents_proper = image[:1] + init_mask = mask_image[:1] + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = init_mask * init_latents_proper + (1 - init_mask) * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + image_embeds = callback_outputs.pop("image_embeds", image_embeds) + negative_image_embeds = callback_outputs.pop("negative_image_embeds", negative_image_embeds) + masked_image = callback_outputs.pop("masked_image", masked_image) + mask_image = callback_outputs.pop("mask_image", mask_image) + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # post-processing + latents = mask_image[:1] * image[:1] + (1 - mask_image[:1]) * latents + + if output_type not in ["pt", "np", "pil", "latent"]: + raise ValueError( + f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}" + ) + + if not output_type == "latent": + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + else: + image = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py new file mode 100644 index 0000000000000000000000000000000000000000..41f4474c3906a83ca4f370558d2bf715cef663d1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py @@ -0,0 +1,557 @@ +from typing import Callable + +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...models import PriorTransformer +from ...schedulers import UnCLIPScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..kandinsky import KandinskyPriorPipelineOutput +from ..pipeline_utils import DiffusionPipeline + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorPipeline + >>> import torch + + >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior") + >>> pipe_prior.to("cuda") + >>> prompt = "red cat, 4k photo" + >>> image_emb, negative_image_emb = pipe_prior(prompt).to_tuple() + + >>> pipe = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder") + >>> pipe.to("cuda") + >>> image = pipe( + ... image_embeds=image_emb, + ... negative_image_embeds=negative_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=50, + ... ).images + >>> image[0].save("cat.png") + ``` +""" + +EXAMPLE_INTERPOLATE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline + >>> from diffusers.utils import load_image + >>> import PIL + >>> import torch + >>> from torchvision import transforms + + >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + >>> img1 = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ) + >>> img2 = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/starry_night.jpeg" + ... ) + >>> images_texts = ["a cat", img1, img2] + >>> weights = [0.3, 0.3, 0.4] + >>> out = pipe_prior.interpolate(images_texts, weights) + >>> pipe = KandinskyV22Pipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + >>> image = pipe( + ... image_embeds=out.image_embeds, + ... negative_image_embeds=out.negative_image_embeds, + ... height=768, + ... width=768, + ... num_inference_steps=50, + ... ).images[0] + >>> image.save("starry_cat.png") + ``` +""" + + +class KandinskyV22PriorPipeline(DiffusionPipeline): + """ + Pipeline for generating image prior for Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen image-encoder. + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + scheduler ([`UnCLIPScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + image_processor ([`CLIPImageProcessor`]): + A image_processor to be used to preprocess image from clip. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->prior" + _exclude_from_cpu_offload = ["prior"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "text_encoder_hidden_states", "text_mask"] + + def __init__( + self, + prior: PriorTransformer, + image_encoder: CLIPVisionModelWithProjection, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + scheduler: UnCLIPScheduler, + image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + prior=prior, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + image_encoder=image_encoder, + image_processor=image_processor, + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING) + def interpolate( + self, + images_and_prompts: list[str | PIL.Image.Image | torch.Tensor], + weights: list[float], + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + negative_prior_prompt: str | None = None, + negative_prompt: str = "", + guidance_scale: float = 4.0, + device=None, + ): + """ + Function invoked when using the prior pipeline for interpolation. + + Args: + images_and_prompts (`list[str | PIL.Image.Image | torch.Tensor]`): + list of prompts and images to guide the image generation. + weights: (`list[float]`): + list of weights for each condition in `images_and_prompts` + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + negative_prior_prompt (`str`, *optional*): + The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than `1`). + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than `1`). + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + + Examples: + + Returns: + [`KandinskyPriorPipelineOutput`] or `tuple` + """ + + device = device or self.device + + if len(images_and_prompts) != len(weights): + raise ValueError( + f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length" + ) + + image_embeddings = [] + for cond, weight in zip(images_and_prompts, weights): + if isinstance(cond, str): + image_emb = self( + cond, + num_inference_steps=num_inference_steps, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + negative_prompt=negative_prior_prompt, + guidance_scale=guidance_scale, + ).image_embeds.unsqueeze(0) + + elif isinstance(cond, (PIL.Image.Image, torch.Tensor)): + if isinstance(cond, PIL.Image.Image): + cond = ( + self.image_processor(cond, return_tensors="pt") + .pixel_values[0] + .unsqueeze(0) + .to(dtype=self.image_encoder.dtype, device=device) + ) + + image_emb = self.image_encoder(cond)["image_embeds"].repeat(num_images_per_prompt, 1).unsqueeze(0) + + else: + raise ValueError( + f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}" + ) + + image_embeddings.append(image_emb * weight) + + image_emb = torch.cat(image_embeddings).sum(dim=0) + + out_zero = self( + negative_prompt, + num_inference_steps=num_inference_steps, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + negative_prompt=negative_prior_prompt, + guidance_scale=guidance_scale, + ) + zero_image_emb = out_zero.negative_image_embeds if negative_prompt == "" else out_zero.image_embeds + + return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=zero_image_emb) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.get_zero_embed + def get_zero_embed(self, batch_size=1, device=None): + device = device or self.device + zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to( + device=device, dtype=self.image_encoder.dtype + ) + zero_image_emb = self.image_encoder(zero_img)["image_embeds"] + zero_image_emb = zero_image_emb.repeat(batch_size, 1) + return zero_image_emb + + # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + prompt_embeds = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds + uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + guidance_scale: float = 4.0, + output_type: str | None = "pt", # pt only + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + output_type (`str`, *optional*, defaults to `"pt"`): + The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"` + (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`KandinskyPriorPipelineOutput`] or `tuple` + """ + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if isinstance(prompt, str): + prompt = [prompt] + elif not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif not isinstance(negative_prompt, list) and negative_prompt is not None: + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + # if the negative prompt is defined we double the batch size to + # directly retrieve the negative prompt embedding + if negative_prompt is not None: + prompt = prompt + negative_prompt + negative_prompt = 2 * negative_prompt + + device = self._execution_device + + batch_size = len(prompt) + batch_size = batch_size * num_images_per_prompt + + self._guidance_scale = guidance_scale + + prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt, device, num_images_per_prompt, self.do_classifier_free_guidance, negative_prompt + ) + + # prior + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + embedding_dim = self.prior.config.embedding_dim + + latents = self.prepare_latents( + (batch_size, embedding_dim), + prompt_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + self._num_timesteps = len(timesteps) + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=prompt_embeds, + encoder_hidden_states=text_encoder_hidden_states, + attention_mask=text_mask, + ).predicted_image_embedding + + if self.do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + self.guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + if i + 1 == timesteps.shape[0]: + prev_timestep = None + else: + prev_timestep = timesteps[i + 1] + + latents = self.scheduler.step( + predicted_image_embedding, + timestep=t, + sample=latents, + generator=generator, + prev_timestep=prev_timestep, + ).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + text_encoder_hidden_states = callback_outputs.pop( + "text_encoder_hidden_states", text_encoder_hidden_states + ) + text_mask = callback_outputs.pop("text_mask", text_mask) + + if XLA_AVAILABLE: + xm.mark_step() + + latents = self.prior.post_process_latents(latents) + + image_embeddings = latents + + # if negative prompt has been defined, we retrieve split the image embedding into two + if negative_prompt is None: + zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device) + else: + image_embeddings, zero_embeds = image_embeddings.chunk(2) + + self.maybe_free_model_hooks() + + if output_type not in ["pt", "np"]: + raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") + + if output_type == "np": + image_embeddings = image_embeddings.cpu().numpy() + zero_embeds = zero_embeds.cpu().numpy() + + if not return_dict: + return (image_embeddings, zero_embeds) + + return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py new file mode 100644 index 0000000000000000000000000000000000000000..adbc3a5badc5cb30647fcdde0a1ebf6f7c4b86fb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py @@ -0,0 +1,569 @@ +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...models import PriorTransformer +from ...schedulers import UnCLIPScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..kandinsky import KandinskyPriorPipelineOutput +from ..pipeline_utils import DiffusionPipeline + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorEmb2EmbPipeline + >>> import torch + + >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + + >>> prompt = "red cat, 4k photo" + >>> img = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ) + >>> image_emb, nagative_image_emb = pipe_prior(prompt, image=img, strength=0.2).to_tuple() + + >>> pipe = KandinskyPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-decoder, torch_dtype=torch.float16" + ... ) + >>> pipe.to("cuda") + + >>> image = pipe( + ... image_embeds=image_emb, + ... negative_image_embeds=negative_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=100, + ... ).images + + >>> image[0].save("cat.png") + ``` +""" + +EXAMPLE_INTERPOLATE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyV22PriorEmb2EmbPipeline, KandinskyV22Pipeline + >>> from diffusers.utils import load_image + >>> import PIL + + >>> import torch + >>> from torchvision import transforms + + >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + + >>> img1 = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ) + + >>> img2 = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/starry_night.jpeg" + ... ) + + >>> images_texts = ["a cat", img1, img2] + >>> weights = [0.3, 0.3, 0.4] + >>> image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights) + + >>> pipe = KandinskyV22Pipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> image = pipe( + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=150, + ... ).images[0] + + >>> image.save("starry_cat.png") + ``` +""" + + +class KandinskyV22PriorEmb2EmbPipeline(DiffusionPipeline): + """ + Pipeline for generating image prior for Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen image-encoder. + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + scheduler ([`UnCLIPScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->prior" + _exclude_from_cpu_offload = ["prior"] + + def __init__( + self, + prior: PriorTransformer, + image_encoder: CLIPVisionModelWithProjection, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + scheduler: UnCLIPScheduler, + image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + prior=prior, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + image_encoder=image_encoder, + image_processor=image_processor, + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING) + def interpolate( + self, + images_and_prompts: list[str | PIL.Image.Image | torch.Tensor], + weights: list[float], + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + negative_prior_prompt: str | None = None, + negative_prompt: str = "", + guidance_scale: float = 4.0, + device=None, + ): + """ + Function invoked when using the prior pipeline for interpolation. + + Args: + images_and_prompts (`list[str | PIL.Image.Image | torch.Tensor]`): + list of prompts and images to guide the image generation. + weights: (`list[float]`): + list of weights for each condition in `images_and_prompts` + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + negative_prior_prompt (`str`, *optional*): + The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than `1`). + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than `1`). + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + + Examples: + + Returns: + [`KandinskyPriorPipelineOutput`] or `tuple` + """ + + device = device or self.device + + if len(images_and_prompts) != len(weights): + raise ValueError( + f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length" + ) + + image_embeddings = [] + for cond, weight in zip(images_and_prompts, weights): + if isinstance(cond, str): + image_emb = self( + cond, + num_inference_steps=num_inference_steps, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + negative_prompt=negative_prior_prompt, + guidance_scale=guidance_scale, + ).image_embeds.unsqueeze(0) + + elif isinstance(cond, (PIL.Image.Image, torch.Tensor)): + image_emb = self._encode_image( + cond, device=device, num_images_per_prompt=num_images_per_prompt + ).unsqueeze(0) + + else: + raise ValueError( + f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}" + ) + + image_embeddings.append(image_emb * weight) + + image_emb = torch.cat(image_embeddings).sum(dim=0) + + return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=torch.randn_like(image_emb)) + + def _encode_image( + self, + image: torch.Tensor | list[PIL.Image.Image], + device, + num_images_per_prompt, + ): + if not isinstance(image, torch.Tensor): + image = self.image_processor(image, return_tensors="pt").pixel_values.to( + dtype=self.image_encoder.dtype, device=device + ) + + image_emb = self.image_encoder(image)["image_embeds"] # B, D + image_emb = image_emb.repeat_interleave(num_images_per_prompt, dim=0) + image_emb.to(device=device) + + return image_emb + + def prepare_latents(self, emb, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + emb = emb.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + init_latents = emb + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.get_zero_embed + def get_zero_embed(self, batch_size=1, device=None): + device = device or self.device + zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to( + device=device, dtype=self.image_encoder.dtype + ) + zero_image_emb = self.image_encoder(zero_img)["image_embeds"] + zero_image_emb = zero_image_emb.repeat(batch_size, 1) + return zero_image_emb + + # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + prompt_embeds = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds + uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str], + image: torch.Tensor | list[torch.Tensor] | PIL.Image.Image | list[PIL.Image.Image], + strength: float = 0.3, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: torch.Generator | list[torch.Generator] | None = None, + guidance_scale: float = 4.0, + output_type: str | None = "pt", # pt only + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `emb`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. + emb (`torch.Tensor`): + The image embedding. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + output_type (`str`, *optional*, defaults to `"pt"`): + The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"` + (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`KandinskyPriorPipelineOutput`] or `tuple` + """ + + if isinstance(prompt, str): + prompt = [prompt] + elif not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif not isinstance(negative_prompt, list) and negative_prompt is not None: + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + # if the negative prompt is defined we double the batch size to + # directly retrieve the negative prompt embedding + if negative_prompt is not None: + prompt = prompt + negative_prompt + negative_prompt = 2 * negative_prompt + + device = self._execution_device + + batch_size = len(prompt) + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + if not isinstance(image, list): + image = [image] + + if isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + + if isinstance(image, torch.Tensor) and image.ndim == 2: + # allow user to pass image_embeds directly + image_embeds = image.repeat_interleave(num_images_per_prompt, dim=0) + elif isinstance(image, torch.Tensor) and image.ndim != 4: + raise ValueError( + f" if pass `image` as pytorch tensor, or a list of pytorch tensor, please make sure each tensor has shape [batch_size, channels, height, width], currently {image[0].unsqueeze(0).shape}" + ) + else: + image_embeds = self._encode_image(image, device, num_images_per_prompt) + + # prior + self.scheduler.set_timesteps(num_inference_steps, device=device) + + latents = image_embeds + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size) + latents = self.prepare_latents( + latents, + latent_timestep, + batch_size // num_images_per_prompt, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=prompt_embeds, + encoder_hidden_states=text_encoder_hidden_states, + attention_mask=text_mask, + ).predicted_image_embedding + + if do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + if i + 1 == timesteps.shape[0]: + prev_timestep = None + else: + prev_timestep = timesteps[i + 1] + + latents = self.scheduler.step( + predicted_image_embedding, + timestep=t, + sample=latents, + generator=generator, + prev_timestep=prev_timestep, + ).prev_sample + + if XLA_AVAILABLE: + xm.mark_step() + + latents = self.prior.post_process_latents(latents) + + image_embeddings = latents + + # if negative prompt has been defined, we retrieve split the image embedding into two + if negative_prompt is None: + zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device) + else: + image_embeddings, zero_embeds = image_embeddings.chunk(2) + + self.maybe_free_model_hooks() + + if output_type not in ["pt", "np"]: + raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") + + if output_type == "np": + image_embeddings = image_embeddings.cpu().numpy() + zero_embeds = zero_embeds.cpu().numpy() + + if not return_dict: + return (image_embeddings, zero_embeds) + + return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky3/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e8a3063141b5e62682ce3ad8e8e0d02473402665 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky3/__init__.py @@ -0,0 +1,49 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_kandinsky3"] = ["Kandinsky3Pipeline"] + _import_structure["pipeline_kandinsky3_img2img"] = ["Kandinsky3Img2ImgPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_kandinsky3 import Kandinsky3Pipeline + from .pipeline_kandinsky3_img2img import Kandinsky3Img2ImgPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..5360632275b41b035f94d84d6156a6bafe84c613 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +import argparse +import fnmatch + +from safetensors.torch import load_file + +from diffusers import Kandinsky3UNet + + +MAPPING = { + "to_time_embed.1": "time_embedding.linear_1", + "to_time_embed.3": "time_embedding.linear_2", + "in_layer": "conv_in", + "out_layer.0": "conv_norm_out", + "out_layer.2": "conv_out", + "down_samples": "down_blocks", + "up_samples": "up_blocks", + "projection_lin": "encoder_hid_proj.projection_linear", + "projection_ln": "encoder_hid_proj.projection_norm", + "feature_pooling": "add_time_condition", + "to_query": "to_q", + "to_key": "to_k", + "to_value": "to_v", + "output_layer": "to_out.0", + "self_attention_block": "attentions.0", +} + +DYNAMIC_MAP = { + "resnet_attn_blocks.*.0": "resnets_in.*", + "resnet_attn_blocks.*.1": ("attentions.*", 1), + "resnet_attn_blocks.*.2": "resnets_out.*", +} +# MAPPING = {} + + +def convert_state_dict(unet_state_dict): + """ + Args: + Convert the state dict of a U-Net model to match the key format expected by Kandinsky3UNet model. + unet_model (torch.nn.Module): The original U-Net model. unet_kandi3_model (torch.nn.Module): The Kandinsky3UNet + model to match keys with. + + Returns: + OrderedDict: The converted state dictionary. + """ + # Example of renaming logic (this will vary based on your model's architecture) + converted_state_dict = {} + for key in unet_state_dict: + new_key = key + for pattern, new_pattern in MAPPING.items(): + new_key = new_key.replace(pattern, new_pattern) + + for dyn_pattern, dyn_new_pattern in DYNAMIC_MAP.items(): + has_matched = False + if fnmatch.fnmatch(new_key, f"*.{dyn_pattern}.*") and not has_matched: + star = int(new_key.split(dyn_pattern.split(".")[0])[-1].split(".")[1]) + + if isinstance(dyn_new_pattern, tuple): + new_star = star + dyn_new_pattern[-1] + dyn_new_pattern = dyn_new_pattern[0] + else: + new_star = star + + pattern = dyn_pattern.replace("*", str(star)) + new_pattern = dyn_new_pattern.replace("*", str(new_star)) + + new_key = new_key.replace(pattern, new_pattern) + has_matched = True + + converted_state_dict[new_key] = unet_state_dict[key] + + return converted_state_dict + + +def main(model_path, output_path): + # Load your original U-Net model + unet_state_dict = load_file(model_path) + + # Initialize your Kandinsky3UNet model + config = {} + + # Convert the state dict + converted_state_dict = convert_state_dict(unet_state_dict) + + unet = Kandinsky3UNet(config) + unet.load_state_dict(converted_state_dict) + + unet.save_pretrained(output_path) + print(f"Converted model saved to {output_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert U-Net PyTorch model to Kandinsky3UNet format") + parser.add_argument("--model_path", type=str, required=True, help="Path to the original U-Net PyTorch model") + parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model") + + args = parser.parse_args() + main(args.model_path, args.output_path) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py new file mode 100644 index 0000000000000000000000000000000000000000..97353c95c9c75b0ee20caaef1f6dd84e8f319dbd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py @@ -0,0 +1,588 @@ +from typing import Callable + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...loaders import StableDiffusionLoraLoaderMixin +from ...models import Kandinsky3UNet, VQModel +from ...schedulers import DDPMScheduler +from ...utils import ( + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import AutoPipelineForText2Image + >>> import torch + + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background." + + >>> generator = torch.Generator(device="cpu").manual_seed(0) + >>> image = pipe(prompt, num_inference_steps=25, generator=generator).images[0] + ``` + +""" + + +def downscale_height_and_width(height, width, scale_factor=8): + new_height = height // scale_factor**2 + if height % scale_factor**2 != 0: + new_height += 1 + new_width = width // scale_factor**2 + if width % scale_factor**2 != 0: + new_width += 1 + return new_height * scale_factor, new_width * scale_factor + + +class Kandinsky3Pipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin): + model_cpu_offload_seq = "text_encoder->unet->movq" + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "negative_attention_mask", + "attention_mask", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: Kandinsky3UNet, + scheduler: DDPMScheduler, + movq: VQModel, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, movq=movq + ) + + def process_embeds(self, embeddings, attention_mask, cut_context): + if cut_context: + embeddings[attention_mask == 0] = torch.zeros_like(embeddings[attention_mask == 0]) + max_seq_length = attention_mask.sum(-1).max() + 1 + embeddings = embeddings[:, :max_seq_length] + attention_mask = attention_mask[:, :max_seq_length] + return embeddings, attention_mask + + @torch.no_grad() + def encode_prompt( + self, + prompt, + do_classifier_free_guidance=True, + num_images_per_prompt=1, + device=None, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + _cut_context=False, + attention_mask: torch.Tensor | None = None, + negative_attention_mask: torch.Tensor | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask. Must provide if passing `prompt_embeds` directly. + negative_attention_mask (`torch.Tensor`, *optional*): + Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + max_length = 128 + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = self.text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + prompt_embeds, attention_mask = self.process_embeds(prompt_embeds, attention_mask, _cut_context) + prompt_embeds = prompt_embeds * attention_mask.unsqueeze(2) + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + attention_mask = attention_mask.repeat(num_images_per_prompt, 1) + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + if negative_prompt is not None: + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=128, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = uncond_input.input_ids.to(device) + negative_attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + text_input_ids, + attention_mask=negative_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds[:, : prompt_embeds.shape[1]] + negative_attention_mask = negative_attention_mask[:, : prompt_embeds.shape[1]] + negative_prompt_embeds = negative_prompt_embeds * negative_attention_mask.unsqueeze(2) + + else: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_attention_mask = torch.zeros_like(attention_mask) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + if negative_prompt_embeds.shape != prompt_embeds.shape: + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_attention_mask = negative_attention_mask.repeat(num_images_per_prompt, 1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + negative_attention_mask = None + return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask + + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def check_inputs( + self, + prompt, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + attention_mask=None, + negative_attention_mask=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if negative_prompt_embeds is not None and negative_attention_mask is None: + raise ValueError("Please provide `negative_attention_mask` along with `negative_prompt_embeds`") + + if negative_prompt_embeds is not None and negative_attention_mask is not None: + if negative_prompt_embeds.shape[:2] != negative_attention_mask.shape: + raise ValueError( + "`negative_prompt_embeds` and `negative_attention_mask` must have the same batch_size and token length when passed directly, but" + f" got: `negative_prompt_embeds` {negative_prompt_embeds.shape[:2]} != `negative_attention_mask`" + f" {negative_attention_mask.shape}." + ) + + if prompt_embeds is not None and attention_mask is None: + raise ValueError("Please provide `attention_mask` along with `prompt_embeds`") + + if prompt_embeds is not None and attention_mask is not None: + if prompt_embeds.shape[:2] != attention_mask.shape: + raise ValueError( + "`prompt_embeds` and `attention_mask` must have the same batch_size and token length when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape[:2]} != `attention_mask`" + f" {attention_mask.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + num_inference_steps: int = 25, + guidance_scale: float = 3.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + height: int | None = 1024, + width: int | None = 1024, + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + negative_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + latents=None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 3.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask. Must provide if passing `prompt_embeds` directly. + negative_attention_mask (`torch.Tensor`, *optional*): + Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + cut_context = True + device = self._execution_device + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + attention_mask, + negative_attention_mask, + ) + + self._guidance_scale = guidance_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt( + prompt, + self.do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + _cut_context=cut_context, + attention_mask=attention_mask, + negative_attention_mask=negative_attention_mask, + ) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool() + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latents + height, width = downscale_height_and_width(height, width, 8) + + latents = self.prepare_latents( + (batch_size * num_images_per_prompt, 4, height, width), + prompt_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=attention_mask, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + + noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond + # noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + ).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + attention_mask = callback_outputs.pop("attention_mask", attention_mask) + negative_attention_mask = callback_outputs.pop("negative_attention_mask", negative_attention_mask) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # post-processing + if output_type not in ["pt", "np", "pil", "latent"]: + raise ValueError( + f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}" + ) + + if not output_type == "latent": + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + else: + image = latents + + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..beb4caafb6d323f6eece7ff4d5a00d74a28f0c64 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py @@ -0,0 +1,634 @@ +import inspect +from typing import Callable + +import PIL +import PIL.Image +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import StableDiffusionLoraLoaderMixin +from ...models import Kandinsky3UNet, VQModel +from ...schedulers import DDPMScheduler +from ...utils import ( + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import AutoPipelineForImage2Image + >>> from diffusers.utils import load_image + >>> import torch + + >>> pipe = AutoPipelineForImage2Image.from_pretrained( + ... "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A painting of the inside of a subway train with tiny raccoons." + >>> image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/t2i.png" + ... ) + + >>> generator = torch.Generator(device="cpu").manual_seed(0) + >>> image = pipe(prompt, image=image, strength=0.75, num_inference_steps=25, generator=generator).images[0] + ``` +""" + + +class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin): + model_cpu_offload_seq = "text_encoder->movq->unet->movq" + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "negative_attention_mask", + "attention_mask", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: Kandinsky3UNet, + scheduler: DDPMScheduler, + movq: VQModel, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, movq=movq + ) + movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8 + movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4 + self.image_processor = VaeImageProcessor( + vae_scale_factor=movq_scale_factor, + vae_latent_channels=movq_latent_channels, + resample="bicubic", + reducing_gap=1, + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + def _process_embeds(self, embeddings, attention_mask, cut_context): + # return embeddings, attention_mask + if cut_context: + embeddings[attention_mask == 0] = torch.zeros_like(embeddings[attention_mask == 0]) + max_seq_length = attention_mask.sum(-1).max() + 1 + embeddings = embeddings[:, :max_seq_length] + attention_mask = attention_mask[:, :max_seq_length] + return embeddings, attention_mask + + @torch.no_grad() + def encode_prompt( + self, + prompt, + do_classifier_free_guidance=True, + num_images_per_prompt=1, + device=None, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + _cut_context=True, + attention_mask: torch.Tensor | None = None, + negative_attention_mask: torch.Tensor | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask. Must provide if passing `prompt_embeds` directly. + negative_attention_mask (`torch.Tensor`, *optional*): + Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + max_length = 128 + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = self.text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + prompt_embeds, attention_mask = self._process_embeds(prompt_embeds, attention_mask, _cut_context) + prompt_embeds = prompt_embeds * attention_mask.unsqueeze(2) + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + attention_mask = attention_mask.repeat(num_images_per_prompt, 1) + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + if negative_prompt is not None: + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=128, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = uncond_input.input_ids.to(device) + negative_attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + text_input_ids, + attention_mask=negative_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds[:, : prompt_embeds.shape[1]] + negative_attention_mask = negative_attention_mask[:, : prompt_embeds.shape[1]] + negative_prompt_embeds = negative_prompt_embeds * negative_attention_mask.unsqueeze(2) + + else: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_attention_mask = torch.zeros_like(attention_mask) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + if negative_prompt_embeds.shape != prompt_embeds.shape: + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_attention_mask = negative_attention_mask.repeat(num_images_per_prompt, 1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + negative_attention_mask = None + return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask + + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + self.movq.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.movq.encode(image).latent_dist.sample(generator) + + init_latents = self.movq.config.scaling_factor * init_latents + + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + attention_mask=None, + negative_attention_mask=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if negative_prompt_embeds is not None and negative_attention_mask is None: + raise ValueError("Please provide `negative_attention_mask` along with `negative_prompt_embeds`") + + if negative_prompt_embeds is not None and negative_attention_mask is not None: + if negative_prompt_embeds.shape[:2] != negative_attention_mask.shape: + raise ValueError( + "`negative_prompt_embeds` and `negative_attention_mask` must have the same batch_size and token length when passed directly, but" + f" got: `negative_prompt_embeds` {negative_prompt_embeds.shape[:2]} != `negative_attention_mask`" + f" {negative_attention_mask.shape}." + ) + + if prompt_embeds is not None and attention_mask is None: + raise ValueError("Please provide `attention_mask` along with `prompt_embeds`") + + if prompt_embeds is not None and attention_mask is not None: + if prompt_embeds.shape[:2] != attention_mask.shape: + raise ValueError( + "`prompt_embeds` and `attention_mask` must have the same batch_size and token length when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape[:2]} != `attention_mask`" + f" {attention_mask.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + image: torch.Tensor | PIL.Image.Image | list[torch.Tensor] | list[PIL.Image.Image] = None, + strength: float = 0.3, + num_inference_steps: int = 25, + guidance_scale: float = 3.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + negative_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 3.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask. Must provide if passing `prompt_embeds` directly. + negative_attention_mask (`torch.Tensor`, *optional*): + Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + + """ + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + cut_context = True + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + attention_mask, + negative_attention_mask, + ) + + self._guidance_scale = guidance_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt( + prompt, + self.do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + _cut_context=cut_context, + attention_mask=attention_mask, + negative_attention_mask=negative_attention_mask, + ) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool() + if not isinstance(image, list): + image = [image] + if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image): + raise ValueError( + f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor" + ) + + image = torch.cat([self.image_processor.preprocess(i) for i in image], dim=0) + image = image.to(dtype=prompt_embeds.dtype, device=device) + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + # 5. Prepare latents + latents = self.movq.encode(image)["latents"] + latents = latents.repeat_interleave(num_images_per_prompt, dim=0) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + latents = self.prepare_latents( + latents, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator + ) + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=attention_mask, + )[0] + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + + noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + ).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + attention_mask = callback_outputs.pop("attention_mask", attention_mask) + negative_attention_mask = callback_outputs.pop("negative_attention_mask", negative_attention_mask) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # post-processing + if not output_type == "latent": + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + image = self.image_processor.postprocess(image, output_type) + else: + image = latents + + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d417ed932b92172d88d5512c9dca8525dc3a7852 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/__init__.py @@ -0,0 +1,54 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_kandinsky"] = ["Kandinsky5T2VPipeline"] + _import_structure["pipeline_kandinsky_i2i"] = ["Kandinsky5I2IPipeline"] + _import_structure["pipeline_kandinsky_i2v"] = ["Kandinsky5I2VPipeline"] + _import_structure["pipeline_kandinsky_t2i"] = ["Kandinsky5T2IPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_kandinsky import Kandinsky5T2VPipeline + from .pipeline_kandinsky_i2i import Kandinsky5I2IPipeline + from .pipeline_kandinsky_i2v import Kandinsky5I2VPipeline + from .pipeline_kandinsky_t2i import Kandinsky5T2IPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a9a6303dba6176009ea44a0c58b9ed56b50e507 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/__pycache__/pipeline_kandinsky.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/__pycache__/pipeline_kandinsky.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7e009ea7af60a663f7d52ae3e209f053ca526ab Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/__pycache__/pipeline_kandinsky.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/__pycache__/pipeline_kandinsky_i2i.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/__pycache__/pipeline_kandinsky_i2i.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e7be0c29e0a6ae170e4b192a066ad365e5213e4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/__pycache__/pipeline_kandinsky_i2i.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/__pycache__/pipeline_kandinsky_i2v.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/__pycache__/pipeline_kandinsky_i2v.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20345bd8ce7ec5b1b5f729877840d0c5107a0455 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/__pycache__/pipeline_kandinsky_i2v.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/__pycache__/pipeline_kandinsky_t2i.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/__pycache__/pipeline_kandinsky_t2i.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abd4937285c23c373d18d39b9f04651528c26a87 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/__pycache__/pipeline_kandinsky_t2i.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/__pycache__/pipeline_output.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/__pycache__/pipeline_output.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7369428a3225a3fe5e4b8ac686fd56875242aa77 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/__pycache__/pipeline_output.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py new file mode 100644 index 0000000000000000000000000000000000000000..1c94a8219e2af50b5bec75197c166cb2e05c01de --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -0,0 +1,968 @@ +# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +from typing import Callable + +import regex as re +import torch +from torch.nn import functional as F +from transformers import CLIPTextModel, CLIPTokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import KandinskyLoraLoaderMixin +from ...models import AutoencoderKLHunyuanVideo +from ...models.transformers import Kandinsky5Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler + +# Add imports for offloading and tiling +from ...utils import ( + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import KandinskyPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) + +EXAMPLE_DOC_STRING = """ + Examples: + + ```python + >>> import torch + >>> from diffusers import Kandinsky5T2VPipeline + >>> from diffusers.utils import export_to_video + + >>> # Available models: + >>> # kandinskylab/Kandinsky-5.0-T2V-Pro-sft-5s-Diffusers + >>> # kandinskylab/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers + >>> # kandinskylab/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers + >>> # kandinskylab/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers + >>> # kandinskylab/Kandinsky-5.0-T2V-Lite-pretrain-5s-Diffusers + >>> # kandinskylab/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers + >>> # kandinskylab/Kandinsky-5.0-T2V-Lite-nocfg-10s-Diffusers + >>> # kandinskylab/Kandinsky-5.0-T2V-Lite-distilled16steps-10s-Diffusers + >>> # kandinskylab/Kandinsky-5.0-T2V-Lite-pretrain-10s-Diffusers + + >>> model_id = "kandinskylab/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers" + >>> pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen." + >>> negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards" + + >>> output = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=512, + ... width=768, + ... num_frames=121, + ... num_inference_steps=50, + ... guidance_scale=5.0, + ... ).frames[0] + + >>> export_to_video(output, "output.mp4", fps=24, quality=9) + ``` +""" + + +def basic_clean(text): + """ + Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py + + Clean text using ftfy if available and unescape HTML entities. + """ + if is_ftfy_available(): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + """ + Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py + + Normalize whitespace in text by replacing multiple spaces with single space. + """ + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + """ + Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py + + Apply both basic cleaning and whitespace normalization to prompts. + """ + text = whitespace_clean(basic_clean(text)) + return text + + +class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using Kandinsky 5.0. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + transformer ([`Kandinsky5Transformer3DModel`]): + Conditional Transformer to denoise the encoded video latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder Model [hunyuanvideo-community/HunyuanVideo + (vae)](https://huggingface.co/hunyuanvideo-community/HunyuanVideo) to encode and decode videos to and from + latent representations. + text_encoder ([`Qwen2_5_VLForConditionalGeneration`]): + Frozen text-encoder [Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct). + tokenizer ([`AutoProcessor`]): + Tokenizer for Qwen2.5-VL. + text_encoder_2 ([`CLIPTextModel`]): + Frozen [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer_2 ([`CLIPTokenizer`]): + Tokenizer for CLIP. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = [ + "latents", + "prompt_embeds_qwen", + "prompt_embeds_clip", + "negative_prompt_embeds_qwen", + "negative_prompt_embeds_clip", + ] + + def __init__( + self, + transformer: Kandinsky5Transformer3DModel, + vae: AutoencoderKLHunyuanVideo, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2VLProcessor, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + transformer=transformer, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + scheduler=scheduler, + ) + + self.prompt_template = "\n".join( + [ + "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.", + "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.", + "Describe the location of the video, main characters or objects and their action.", + "Describe the dynamism of the video and presented actions.", + "Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.", + "Describe the visual effects, postprocessing and transitions if they are presented in the video.", + "Pay attention to the order of key actions shown in the scene.<|im_end|>", + "<|im_start|>user\n{}<|im_end|>", + ] + ) + self.prompt_template_encode_start_idx = 129 + + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4 + ) + self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_scale_factor(self, height: int, width: int) -> tuple: + """ + Calculate the scale factor based on resolution. + + Args: + height (int): Video height + width (int): Video width + + Returns: + tuple: Scale factor as (temporal_scale, height_scale, width_scale) + """ + + def between_480p(x): + return 480 <= x <= 854 + + if between_480p(height) and between_480p(width): + return (1, 2, 2) + else: + return (1, 3.16, 3.16) + + @staticmethod + def fast_sta_nabla(T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda") -> torch.Tensor: + """ + Create a sparse temporal attention (STA) mask for efficient video generation. + + This method generates a mask that limits attention to nearby frames and spatial positions, reducing + computational complexity for video generation. + + Args: + T (int): Number of temporal frames + H (int): Height in latent space + W (int): Width in latent space + wT (int): Temporal attention window size + wH (int): Height attention window size + wW (int): Width attention window size + device (str): Device to create tensor on + + Returns: + torch.Tensor: Sparse attention mask of shape (T*H*W, T*H*W) + """ + l = torch.Tensor([T, H, W]).amax() + r = torch.arange(0, l, 1, dtype=torch.int16, device=device) + mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs() + sta_t, sta_h, sta_w = ( + mat[:T, :T].flatten(), + mat[:H, :H].flatten(), + mat[:W, :W].flatten(), + ) + sta_t = sta_t <= wT // 2 + sta_h = sta_h <= wH // 2 + sta_w = sta_w <= wW // 2 + sta_hw = (sta_h.unsqueeze(1) * sta_w.unsqueeze(0)).reshape(H, H, W, W).transpose(1, 2).flatten() + sta = (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0)).reshape(T, T, H * W, H * W).transpose(1, 2) + return sta.reshape(T * H * W, T * H * W) + + def get_sparse_params(self, sample, device): + """ + Generate sparse attention parameters for the transformer based on sample dimensions. + + This method computes the sparse attention configuration needed for efficient video processing in the + transformer model. + + Args: + sample (torch.Tensor): Input sample tensor + device (torch.device): Device to place tensors on + + Returns: + Dict: Dictionary containing sparse attention parameters + """ + assert self.transformer.config.patch_size[0] == 1 + B, T, H, W, _ = sample.shape + T, H, W = ( + T // self.transformer.config.patch_size[0], + H // self.transformer.config.patch_size[1], + W // self.transformer.config.patch_size[2], + ) + if self.transformer.config.attention_type == "nabla": + sta_mask = self.fast_sta_nabla( + T, + H // 8, + W // 8, + self.transformer.config.attention_wT, + self.transformer.config.attention_wH, + self.transformer.config.attention_wW, + device=device, + ) + + sparse_params = { + "sta_mask": sta_mask.unsqueeze_(0).unsqueeze_(0), + "attention_type": self.transformer.config.attention_type, + "to_fractal": True, + "P": self.transformer.config.attention_P, + "wT": self.transformer.config.attention_wT, + "wW": self.transformer.config.attention_wW, + "wH": self.transformer.config.attention_wH, + "add_sta": self.transformer.config.attention_add_sta, + "visual_shape": (T, H, W), + "method": self.transformer.config.attention_method, + } + else: + sparse_params = None + + return sparse_params + + def _encode_prompt_qwen( + self, + prompt: str | list[str], + device: torch.device | None = None, + max_sequence_length: int = 256, + dtype: torch.dtype | None = None, + ): + """ + Encode prompt using Qwen2.5-VL text encoder. + + This method processes the input prompt through the Qwen2.5-VL model to generate text embeddings suitable for + video generation. + + Args: + prompt (str | list[str]): Input prompt or list of prompts + device (torch.device): Device to run encoding on + num_videos_per_prompt (int): Number of videos to generate per prompt + max_sequence_length (int): Maximum sequence length for tokenization + dtype (torch.dtype): Data type for embeddings + + Returns: + tuple[torch.Tensor, torch.Tensor]: Text embeddings and cumulative sequence lengths + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + full_texts = [self.prompt_template.format(p) for p in prompt] + max_allowed_len = self.prompt_template_encode_start_idx + max_sequence_length + + untruncated_ids = self.tokenizer( + text=full_texts, + images=None, + videos=None, + return_tensors="pt", + padding="longest", + )["input_ids"] + + if untruncated_ids.shape[-1] > max_allowed_len: + for i, text in enumerate(full_texts): + tokens = untruncated_ids[i][self.prompt_template_encode_start_idx : -2] + removed_text = self.tokenizer.decode(tokens[max_sequence_length - 2 :]) + if len(removed_text) > 0: + full_texts[i] = text[: -len(removed_text)] + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + inputs = self.tokenizer( + text=full_texts, + images=None, + videos=None, + max_length=max_allowed_len, + truncation=True, + return_tensors="pt", + padding=True, + ).to(device) + + embeds = self.text_encoder( + input_ids=inputs["input_ids"], + return_dict=True, + output_hidden_states=True, + )["hidden_states"][-1][:, self.prompt_template_encode_start_idx :] + + attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx :] + cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32) + + return embeds.to(dtype), cu_seqlens + + def _encode_prompt_clip( + self, + prompt: str | list[str], + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + """ + Encode prompt using CLIP text encoder. + + This method processes the input prompt through the CLIP model to generate pooled embeddings that capture + semantic information. + + Args: + prompt (str | list[str]): Input prompt or list of prompts + device (torch.device): Device to run encoding on + num_videos_per_prompt (int): Number of videos to generate per prompt + dtype (torch.dtype): Data type for embeddings + + Returns: + torch.Tensor: Pooled text embeddings from CLIP + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype + + inputs = self.tokenizer_2( + prompt, + max_length=77, + truncation=True, + add_special_tokens=True, + padding="max_length", + return_tensors="pt", + ).to(device) + + pooled_embed = self.text_encoder_2(**inputs)["pooler_output"] + + return pooled_embed.to(dtype) + + def encode_prompt( + self, + prompt: str | list[str], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes a single prompt (positive or negative) into text encoder hidden states. + + This method combines embeddings from both Qwen2.5-VL and CLIP text encoders to create comprehensive text + representations for video generation. + + Args: + prompt (`str` or `list[str]`): + Prompt to be encoded. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos to generate per prompt. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length for text encoding. + device (`torch.device`, *optional*): + Torch device. + dtype (`torch.dtype`, *optional*): + Torch dtype. + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + - Qwen text embeddings of shape (batch_size * num_videos_per_prompt, sequence_length, embedding_dim) + - CLIP pooled embeddings of shape (batch_size * num_videos_per_prompt, clip_embedding_dim) + - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings of shape (batch_size * + num_videos_per_prompt + 1,) + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + if not isinstance(prompt, list): + prompt = [prompt] + + batch_size = len(prompt) + + prompt = [prompt_clean(p) for p in prompt] + + # Encode with Qwen2.5-VL + prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen( + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length, + dtype=dtype, + ) + # prompt_embeds_qwen shape: [batch_size, seq_len, embed_dim] + + # Encode with CLIP + prompt_embeds_clip = self._encode_prompt_clip( + prompt=prompt, + device=device, + dtype=dtype, + ) + # prompt_embeds_clip shape: [batch_size, clip_embed_dim] + + # Repeat embeddings for num_videos_per_prompt + # Qwen embeddings: repeat sequence for each video, then reshape + prompt_embeds_qwen = prompt_embeds_qwen.repeat( + 1, num_videos_per_prompt, 1 + ) # [batch_size, seq_len * num_videos_per_prompt, embed_dim] + # Reshape to [batch_size * num_videos_per_prompt, seq_len, embed_dim] + prompt_embeds_qwen = prompt_embeds_qwen.view( + batch_size * num_videos_per_prompt, -1, prompt_embeds_qwen.shape[-1] + ) + + # CLIP embeddings: repeat for each video + prompt_embeds_clip = prompt_embeds_clip.repeat( + 1, num_videos_per_prompt, 1 + ) # [batch_size, num_videos_per_prompt, clip_embed_dim] + # Reshape to [batch_size * num_videos_per_prompt, clip_embed_dim] + prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_videos_per_prompt, -1) + + # Repeat cumulative sequence lengths for num_videos_per_prompt + # Original cu_seqlens: [0, len1, len1+len2, ...] + # Need to repeat the differences and reconstruct for repeated prompts + # Original differences (lengths) for each prompt in the batch + original_lengths = prompt_cu_seqlens.diff() # [len1, len2, ...] + # Repeat the lengths for num_videos_per_prompt + repeated_lengths = original_lengths.repeat_interleave( + num_videos_per_prompt + ) # [len1, len1, ..., len2, len2, ...] + # Reconstruct the cumulative lengths + repeated_cu_seqlens = torch.cat( + [torch.tensor([0], device=device, dtype=torch.int32), repeated_lengths.cumsum(0)] + ) + + return prompt_embeds_qwen, prompt_embeds_clip, repeated_cu_seqlens + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds_qwen=None, + prompt_embeds_clip=None, + negative_prompt_embeds_qwen=None, + negative_prompt_embeds_clip=None, + prompt_cu_seqlens=None, + negative_prompt_cu_seqlens=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + """ + Validate input parameters for the pipeline. + + Args: + prompt: Input prompt + negative_prompt: Negative prompt for guidance + height: Video height + width: Video width + prompt_embeds_qwen: Pre-computed Qwen prompt embeddings + prompt_embeds_clip: Pre-computed CLIP prompt embeddings + negative_prompt_embeds_qwen: Pre-computed Qwen negative prompt embeddings + negative_prompt_embeds_clip: Pre-computed CLIP negative prompt embeddings + prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen positive prompt + negative_prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen negative prompt + callback_on_step_end_tensor_inputs: Callback tensor inputs + + Raises: + ValueError: If inputs are invalid + """ + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError("max_sequence_length must be less than 1024") + + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + # Check for consistency within positive prompt embeddings and sequence lengths + if prompt_embeds_qwen is not None or prompt_embeds_clip is not None or prompt_cu_seqlens is not None: + if prompt_embeds_qwen is None or prompt_embeds_clip is None or prompt_cu_seqlens is None: + raise ValueError( + "If any of `prompt_embeds_qwen`, `prompt_embeds_clip`, or `prompt_cu_seqlens` is provided, " + "all three must be provided." + ) + + # Check for consistency within negative prompt embeddings and sequence lengths + if ( + negative_prompt_embeds_qwen is not None + or negative_prompt_embeds_clip is not None + or negative_prompt_cu_seqlens is not None + ): + if ( + negative_prompt_embeds_qwen is None + or negative_prompt_embeds_clip is None + or negative_prompt_cu_seqlens is None + ): + raise ValueError( + "If any of `negative_prompt_embeds_qwen`, `negative_prompt_embeds_clip`, or `negative_prompt_cu_seqlens` is provided, " + "all three must be provided." + ) + + # Check if prompt or embeddings are provided (either prompt or all required embedding components for positive) + if prompt is None and prompt_embeds_qwen is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_qwen` (and corresponding `prompt_embeds_clip` and `prompt_cu_seqlens`). Cannot leave all undefined." + ) + + # Validate types for prompt and negative_prompt if provided + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Prepare initial latent variables for video generation. + + This method creates random noise latents or uses provided latents as starting point for the denoising process. + + Args: + batch_size (int): Number of videos to generate + num_channels_latents (int): Number of channels in latent space + height (int): Height of generated video + width (int): Width of generated video + num_frames (int): Number of frames in video + dtype (torch.dtype): Data type for latents + device (torch.device): Device to create latents on + generator (torch.Generator): Random number generator + latents (torch.Tensor): Pre-existing latents to use + + Returns: + torch.Tensor: Prepared latent tensor + """ + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + num_channels_latents, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + if self.transformer.visual_cond: + # For visual conditioning, concatenate with zeros and mask + visual_cond = torch.zeros_like(latents) + visual_cond_mask = torch.zeros( + [ + batch_size, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + 1, + ], + dtype=latents.dtype, + device=latents.device, + ) + latents = torch.cat([latents, visual_cond, visual_cond_mask], dim=-1) + + return latents + + @property + def guidance_scale(self): + """Get the current guidance scale value.""" + return self._guidance_scale + + @property + def num_timesteps(self): + """Get the number of denoising timesteps.""" + return self._num_timesteps + + @property + def interrupt(self): + """Check if generation has been interrupted.""" + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds_qwen: torch.Tensor | None = None, + prompt_embeds_clip: torch.Tensor | None = None, + negative_prompt_embeds_qwen: torch.Tensor | None = None, + negative_prompt_embeds_clip: torch.Tensor | None = None, + prompt_cu_seqlens: torch.Tensor | None = None, + negative_prompt_cu_seqlens: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to avoid during video generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). + height (`int`, defaults to `512`): + The height in pixels of the generated video. + width (`int`, defaults to `768`): + The width in pixels of the generated video. + num_frames (`int`, defaults to `25`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in classifier-free guidance. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A torch generator to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`KandinskyPipelineOutput`]. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function that is called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length for text encoding. + + Examples: + + Returns: + [`~KandinskyPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`KandinskyPipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images. + """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + prompt_embeds_qwen=prompt_embeds_qwen, + prompt_embeds_clip=prompt_embeds_clip, + negative_prompt_embeds_qwen=negative_prompt_embeds_qwen, + negative_prompt_embeds_clip=negative_prompt_embeds_clip, + prompt_cu_seqlens=prompt_cu_seqlens, + negative_prompt_cu_seqlens=negative_prompt_cu_seqlens, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._interrupt = False + + device = self._execution_device + dtype = self.transformer.dtype + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds_qwen.shape[0] + + # 3. Encode input prompt + if prompt_embeds_qwen is None: + prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = self.encode_prompt( + prompt=prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if self.guidance_scale > 1.0: + if negative_prompt is None: + negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards" + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * len(prompt) if prompt is not None else [negative_prompt] + elif len(negative_prompt) != len(prompt): + raise ValueError( + f"`negative_prompt` must have same length as `prompt`. Got {len(negative_prompt)} vs {len(prompt)}." + ) + + if negative_prompt_embeds_qwen is None: + negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_prompt_cu_seqlens = ( + self.encode_prompt( + prompt=negative_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_visual_dim + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + dtype, + device, + generator, + latents, + ) + + # 6. Prepare rope positions for positional encoding + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + visual_rope_pos = [ + torch.arange(num_latent_frames, device=device), + torch.arange(height // self.vae_scale_factor_spatial // 2, device=device), + torch.arange(width // self.vae_scale_factor_spatial // 2, device=device), + ] + + text_rope_pos = torch.arange(prompt_cu_seqlens.diff().max().item(), device=device) + + negative_text_rope_pos = ( + torch.arange(negative_prompt_cu_seqlens.diff().max().item(), device=device) + if negative_prompt_cu_seqlens is not None + else None + ) + + # 7. Calculate dynamic scale factor based on resolution + scale_factor = self._get_scale_factor(height, width) + + # 8. Sparse Params for efficient attention + sparse_params = self.get_sparse_params(latents, device) + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + timestep = t.unsqueeze(0).repeat(batch_size * num_videos_per_prompt) + + # Predict noise residual + pred_velocity = self.transformer( + hidden_states=latents.to(dtype), + encoder_hidden_states=prompt_embeds_qwen.to(dtype), + pooled_projections=prompt_embeds_clip.to(dtype), + timestep=timestep.to(dtype), + visual_rope_pos=visual_rope_pos, + text_rope_pos=text_rope_pos, + scale_factor=scale_factor, + sparse_params=sparse_params, + return_dict=True, + ).sample + + if self.guidance_scale > 1.0 and negative_prompt_embeds_qwen is not None: + uncond_pred_velocity = self.transformer( + hidden_states=latents.to(dtype), + encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype), + pooled_projections=negative_prompt_embeds_clip.to(dtype), + timestep=timestep.to(dtype), + visual_rope_pos=visual_rope_pos, + text_rope_pos=negative_text_rope_pos, + scale_factor=scale_factor, + sparse_params=sparse_params, + return_dict=True, + ).sample + + pred_velocity = uncond_pred_velocity + guidance_scale * (pred_velocity - uncond_pred_velocity) + # Compute previous sample using the scheduler + latents[:, :, :, :, :num_channels_latents] = self.scheduler.step( + pred_velocity, t, latents[:, :, :, :, :num_channels_latents], return_dict=False + )[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds_qwen = callback_outputs.pop("prompt_embeds_qwen", prompt_embeds_qwen) + prompt_embeds_clip = callback_outputs.pop("prompt_embeds_clip", prompt_embeds_clip) + negative_prompt_embeds_qwen = callback_outputs.pop( + "negative_prompt_embeds_qwen", negative_prompt_embeds_qwen + ) + negative_prompt_embeds_clip = callback_outputs.pop( + "negative_prompt_embeds_clip", negative_prompt_embeds_clip + ) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # 10. Post-processing - extract main latents + latents = latents[:, :, :, :, :num_channels_latents] + + # 11. Decode latents to video + if output_type != "latent": + latents = latents.to(self.vae.dtype) + # Reshape and normalize latents + video = latents.reshape( + batch_size, + num_videos_per_prompt, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + num_channels_latents, + ) + video = video.permute(0, 1, 5, 2, 3, 4) # [batch, num_videos, channels, frames, height, width] + video = video.reshape( + batch_size * num_videos_per_prompt, + num_channels_latents, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + # Normalize and decode through VAE + video = video / self.vae.config.scaling_factor + video = self.vae.decode(video).sample + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return KandinskyPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/pipeline_kandinsky_i2i.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/pipeline_kandinsky_i2i.py new file mode 100644 index 0000000000000000000000000000000000000000..244db7300767b7d8a901b5301a6a2089af3dc6e1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/pipeline_kandinsky_i2i.py @@ -0,0 +1,861 @@ +# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +from typing import Callable + +import numpy as np +import regex as re +import torch +from torch.nn import functional as F +from transformers import CLIPTextModel, CLIPTokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import KandinskyLoraLoaderMixin +from ...models import AutoencoderKL +from ...models.transformers import Kandinsky5Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler + +# Add imports for offloading and tiling +from ...utils import ( + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import KandinskyImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) + +EXAMPLE_DOC_STRING = """ + Examples: + + ```python + >>> import torch + >>> from diffusers import Kandinsky5I2IPipeline + + >>> # Available models: + >>> # kandinskylab/Kandinsky-5.0-I2I-Lite-sft-Diffusers + >>> # kandinskylab/Kandinsky-5.0-I2I-Lite-pretrain-Diffusers + + >>> model_id = "kandinskylab/Kandinsky-5.0-I2I-Lite-sft-Diffusers" + >>> pipe = Kandinsky5I2IPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen." + + >>> output = pipe( + ... prompt=prompt, + ... negative_prompt="", + ... height=1024, + ... width=1024, + ... num_inference_steps=50, + ... guidance_scale=3.5, + ... ).frames[0] + ``` +""" + + +def basic_clean(text): + """ + Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py + + Clean text using ftfy if available and unescape HTML entities. + """ + if is_ftfy_available(): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + """ + Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py + + Normalize whitespace in text by replacing multiple spaces with single space. + """ + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + """ + Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py + + Apply both basic cleaning and whitespace normalization to prompts. + """ + text = whitespace_clean(basic_clean(text)) + return text + + +class Kandinsky5I2IPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): + r""" + Pipeline for image-to-image generation using Kandinsky 5.0. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + transformer ([`Kandinsky5Transformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder Model [black-forest-labs/FLUX.1-dev + (vae)](https://huggingface.co/black-forest-labs/FLUX.1-dev) to encode and decode videos to and from latent + representations. + text_encoder ([`Qwen2_5_VLForConditionalGeneration`]): + Frozen text-encoder [Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct). + tokenizer ([`AutoProcessor`]): + Tokenizer for Qwen2.5-VL. + text_encoder_2 ([`CLIPTextModel`]): + Frozen [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer_2 ([`CLIPTokenizer`]): + Tokenizer for CLIP. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = [ + "latents", + "prompt_embeds_qwen", + "prompt_embeds_clip", + "negative_prompt_embeds_qwen", + "negative_prompt_embeds_clip", + ] + + def __init__( + self, + transformer: Kandinsky5Transformer3DModel, + vae: AutoencoderKL, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2VLProcessor, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + transformer=transformer, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + scheduler=scheduler, + ) + self.prompt_template = "<|im_start|>system\nYou are a promt engineer. Based on the provided source image (first image) and target image (second image), create an interesting text prompt that can be used together with the source image to create the target image:<|im_end|><|im_start|>user{}<|vision_start|><|image_pad|><|vision_end|><|im_end|>" + self.prompt_template_encode_start_idx = 55 + + self.vae_scale_factor_spatial = 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.resolutions = [(1024, 1024), (640, 1408), (1408, 640), (768, 1280), (1280, 768), (896, 1152), (1152, 896)] + + def _encode_prompt_qwen( + self, + prompt: list[str], + image: PipelineImageInput | None = None, + device: torch.device | None = None, + max_sequence_length: int = 1024, + dtype: torch.dtype | None = None, + ): + """ + Encode prompt using Qwen2.5-VL text encoder. + + This method processes the input prompt through the Qwen2.5-VL model to generate text embeddings suitable for + image generation. + + Args: + prompt list[str]: Input list of prompts + image (PipelineImageInput): Input list of images to condition the generation on + device (torch.device): Device to run encoding on + max_sequence_length (int): Maximum sequence length for tokenization + dtype (torch.dtype): Data type for embeddings + + Returns: + tuple[torch.Tensor, torch.Tensor]: Text embeddings and cumulative sequence lengths + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + if not isinstance(image, list): + image = [image] + image = [i.resize((i.size[0] // 2, i.size[1] // 2)) for i in image] + full_texts = [self.prompt_template.format(p) for p in prompt] + max_allowed_len = self.prompt_template_encode_start_idx + max_sequence_length + + untruncated_ids = self.tokenizer( + text=full_texts, + images=image, + videos=None, + return_tensors="pt", + padding="longest", + )["input_ids"] + + if untruncated_ids.shape[-1] > max_allowed_len: + for i, text in enumerate(full_texts): + tokens = untruncated_ids[i] + num_image_tokens = (tokens == self.tokenizer.image_token_id).sum() + tokens = tokens[tokens != self.tokenizer.image_token_id][self.prompt_template_encode_start_idx : -3] + removed_text = self.tokenizer.decode(tokens[max_sequence_length - num_image_tokens - 3 :]) + if len(removed_text) > 0: + full_texts[i] = text[: -len(removed_text)] + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + inputs = self.tokenizer( + text=full_texts, + images=image, + videos=None, + max_length=max_allowed_len, + truncation=True, + return_tensors="pt", + padding=True, + ).to(device) + + embeds = self.text_encoder( + **inputs, + return_dict=True, + output_hidden_states=True, + )["hidden_states"][-1][:, self.prompt_template_encode_start_idx :] + + attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx :] + cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32) + + return embeds.to(dtype), cu_seqlens + + def _encode_prompt_clip( + self, + prompt: str | list[str], + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + """ + Encode prompt using CLIP text encoder. + + This method processes the input prompt through the CLIP model to generate pooled embeddings that capture + semantic information. + + Args: + prompt (str | list[str]): Input prompt or list of prompts + device (torch.device): Device to run encoding on + dtype (torch.dtype): Data type for embeddings + + Returns: + torch.Tensor: Pooled text embeddings from CLIP + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype + + inputs = self.tokenizer_2( + prompt, + max_length=77, + truncation=True, + add_special_tokens=True, + padding="max_length", + return_tensors="pt", + ).to(device) + + pooled_embed = self.text_encoder_2(**inputs)["pooler_output"] + + return pooled_embed.to(dtype) + + def encode_prompt( + self, + prompt: str | list[str], + image: torch.Tensor, + num_images_per_prompt: int = 1, + max_sequence_length: int = 1024, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes a single prompt (positive or negative) into text encoder hidden states. + + This method combines embeddings from both Qwen2.5-VL and CLIP text encoders to create comprehensive text + representations for image generation. + + Args: + prompt (`str` or `list[str]`): + Prompt to be encoded. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images to generate per prompt. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for text encoding. Must be less than 1024 + device (`torch.device`, *optional*): + Torch device. + dtype (`torch.dtype`, *optional*): + Torch dtype. + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + - Qwen text embeddings of shape (batch_size * num_images_per_prompt, sequence_length, embedding_dim) + - CLIP pooled embeddings of shape (batch_size * num_images_per_prompt, clip_embedding_dim) + - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings of shape (batch_size * + num_images_per_prompt + 1,) + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + if not isinstance(prompt, list): + prompt = [prompt] + + batch_size = len(prompt) + + prompt = [prompt_clean(p) for p in prompt] + + # Encode with Qwen2.5-VL + prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen( + prompt=prompt, + image=image, + device=device, + max_sequence_length=max_sequence_length, + dtype=dtype, + ) + # prompt_embeds_qwen shape: [batch_size, seq_len, embed_dim] + + # Encode with CLIP + prompt_embeds_clip = self._encode_prompt_clip( + prompt=prompt, + device=device, + dtype=dtype, + ) + # prompt_embeds_clip shape: [batch_size, clip_embed_dim] + + # Repeat embeddings for num_images_per_prompt + # Qwen embeddings: repeat sequence for each image, then reshape + prompt_embeds_qwen = prompt_embeds_qwen.repeat( + 1, num_images_per_prompt, 1 + ) # [batch_size, seq_len * num_images_per_prompt, embed_dim] + # Reshape to [batch_size * num_images_per_prompt, seq_len, embed_dim] + prompt_embeds_qwen = prompt_embeds_qwen.view( + batch_size * num_images_per_prompt, -1, prompt_embeds_qwen.shape[-1] + ) + + # CLIP embeddings: repeat for each image + prompt_embeds_clip = prompt_embeds_clip.repeat( + 1, num_images_per_prompt, 1 + ) # [batch_size, num_images_per_prompt, clip_embed_dim] + # Reshape to [batch_size * num_images_per_prompt, clip_embed_dim] + prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_images_per_prompt, -1) + + # Repeat cumulative sequence lengths for num_images_per_prompt + # Original differences (lengths) for each prompt in the batch + original_lengths = prompt_cu_seqlens.diff() # [len1, len2, ...] + # Repeat the lengths for num_images_per_prompt + repeated_lengths = original_lengths.repeat_interleave( + num_images_per_prompt + ) # [len1, len1, ..., len2, len2, ...] + # Reconstruct the cumulative lengths + repeated_cu_seqlens = torch.cat( + [torch.tensor([0], device=device, dtype=torch.int32), repeated_lengths.cumsum(0)] + ) + + return prompt_embeds_qwen, prompt_embeds_clip, repeated_cu_seqlens + + def check_inputs( + self, + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds_qwen=None, + prompt_embeds_clip=None, + negative_prompt_embeds_qwen=None, + negative_prompt_embeds_clip=None, + prompt_cu_seqlens=None, + negative_prompt_cu_seqlens=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + """ + Validate input parameters for the pipeline. + + Args: + prompt: Input prompt + negative_prompt: Negative prompt for guidance + image: Input image for conditioning + height: Image height + width: Image width + prompt_embeds_qwen: Pre-computed Qwen prompt embeddings + prompt_embeds_clip: Pre-computed CLIP prompt embeddings + negative_prompt_embeds_qwen: Pre-computed Qwen negative prompt embeddings + negative_prompt_embeds_clip: Pre-computed CLIP negative prompt embeddings + prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen positive prompt + negative_prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen negative prompt + callback_on_step_end_tensor_inputs: Callback tensor inputs + + Raises: + ValueError: If inputs are invalid + """ + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError("max_sequence_length must be less than 1024") + + if image is None: + raise ValueError("`image` must be provided for image-to-image generation") + + if (width, height) not in self.resolutions: + resolutions_str = ",".join([f"({w},{h})" for w, h in self.resolutions]) + logger.warning( + f"`height` and `width` have to be one of {resolutions_str}, but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + # Check for consistency within positive prompt embeddings and sequence lengths + if prompt_embeds_qwen is not None or prompt_embeds_clip is not None or prompt_cu_seqlens is not None: + if prompt_embeds_qwen is None or prompt_embeds_clip is None or prompt_cu_seqlens is None: + raise ValueError( + "If any of `prompt_embeds_qwen`, `prompt_embeds_clip`, or `prompt_cu_seqlens` is provided, " + "all three must be provided." + ) + + # Check for consistency within negative prompt embeddings and sequence lengths + if ( + negative_prompt_embeds_qwen is not None + or negative_prompt_embeds_clip is not None + or negative_prompt_cu_seqlens is not None + ): + if ( + negative_prompt_embeds_qwen is None + or negative_prompt_embeds_clip is None + or negative_prompt_cu_seqlens is None + ): + raise ValueError( + "If any of `negative_prompt_embeds_qwen`, `negative_prompt_embeds_clip`, or `negative_prompt_cu_seqlens` is provided, " + "all three must be provided." + ) + + # Check if prompt or embeddings are provided (either prompt or all required embedding components for positive) + if prompt is None and prompt_embeds_qwen is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_qwen` (and corresponding `prompt_embeds_clip` and `prompt_cu_seqlens`). Cannot leave all undefined." + ) + + # Validate types for prompt and negative_prompt if provided + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + def prepare_latents( + self, + image: PipelineImageInput, + batch_size: int, + num_channels_latents: int = 16, + height: int = 1024, + width: int = 1024, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Prepare initial latent variables for image-to-image generation. + + This method creates random noise latents with encoded image, + + Args: + image (PipelineImageInput): Input image to condition the generation on + batch_size (int): Number of images to generate + num_channels_latents (int): Number of channels in latent space + height (int): Height of generated image + width (int): Width of generated image + dtype (torch.dtype): Data type for latents + device (torch.device): Device to create latents on + generator (torch.Generator): Random number generator + latents (torch.Tensor): Pre-existing latents to use + + Returns: + torch.Tensor: Prepared latent tensor with encoded image + """ + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + 1, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + num_channels_latents, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # Generate random noise for all frames + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # Encode the input image to use as first frame + # Preprocess image + image_tensor = self.image_processor.preprocess(image, height=height, width=width).to(device, dtype=dtype) + # Encode image to latents using VAE + with torch.no_grad(): + image_latents = self.vae.encode(image_tensor).latent_dist.sample(generator=generator) + image_latents = image_latents.unsqueeze(2) # Add temporal dimension + + # Normalize latents if needed + if hasattr(self.vae.config, "scaling_factor"): + image_latents = image_latents * self.vae.config.scaling_factor + + # Reshape to match latent dimensions [batch, 1, height, width, channels] + image_latents = image_latents.permute(0, 2, 3, 4, 1) # [batch, 1, H, W, C] + latents = torch.cat([latents, image_latents, torch.ones_like(latents[..., :1])], -1) + + return latents + + @property + def guidance_scale(self): + """Get the current guidance scale value.""" + return self._guidance_scale + + @property + def num_timesteps(self): + """Get the number of denoising timesteps.""" + return self._num_timesteps + + @property + def interrupt(self): + """Check if generation has been interrupted.""" + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 3.5, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds_qwen: torch.Tensor | None = None, + prompt_embeds_clip: torch.Tensor | None = None, + negative_prompt_embeds_qwen: torch.Tensor | None = None, + negative_prompt_embeds_clip: torch.Tensor | None = None, + prompt_cu_seqlens: torch.Tensor | None = None, + negative_prompt_cu_seqlens: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int, None], PipelineCallback | MultiPipelineCallbacks] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 1024, + ): + r""" + The call function to the pipeline for image-to-image generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in classifier-free guidance. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A torch generator to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents. + prompt_embeds_qwen (`torch.Tensor`, *optional*): + Pre-generated Qwen text embeddings. + prompt_embeds_clip (`torch.Tensor`, *optional*): + Pre-generated CLIP text embeddings. + negative_prompt_embeds_qwen (`torch.Tensor`, *optional*): + Pre-generated Qwen negative text embeddings. + negative_prompt_embeds_clip (`torch.Tensor`, *optional*): + Pre-generated CLIP negative text embeddings. + prompt_cu_seqlens (`torch.Tensor`, *optional*): + Pre-generated cumulative sequence lengths for Qwen positive prompt. + negative_prompt_cu_seqlens (`torch.Tensor`, *optional*): + Pre-generated cumulative sequence lengths for Qwen negative prompt. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`KandinskyImagePipelineOutput`]. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function that is called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. + max_sequence_length (`int`, defaults to `1024`): + The maximum sequence length for text and image qwen encoding. Must be less than 1024 + + Examples: + + Returns: + [`~KandinskyImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`KandinskyImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + # 1. Check inputs. Raise error if not correct + if height is None and width is None: + width, height = image[0].size if isinstance(image, list) else image.size + self.check_inputs( + prompt=prompt, + negative_prompt=negative_prompt, + image=image, + height=height, + width=width, + prompt_embeds_qwen=prompt_embeds_qwen, + prompt_embeds_clip=prompt_embeds_clip, + negative_prompt_embeds_qwen=negative_prompt_embeds_qwen, + negative_prompt_embeds_clip=negative_prompt_embeds_clip, + prompt_cu_seqlens=prompt_cu_seqlens, + negative_prompt_cu_seqlens=negative_prompt_cu_seqlens, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + if (width, height) not in self.resolutions: + width, height = self.resolutions[ + np.argmin([abs((i[0] / i[1]) - (width / height)) for i in self.resolutions]) + ] + + self._guidance_scale = guidance_scale + self._interrupt = False + + device = self._execution_device + dtype = self.transformer.dtype + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds_qwen.shape[0] + + # 3. Encode input prompt + if prompt_embeds_qwen is None: + prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = self.encode_prompt( + prompt=prompt, + image=image, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if self.guidance_scale > 1.0: + if negative_prompt is None: + negative_prompt = "" + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * len(prompt) if prompt is not None else [negative_prompt] + elif len(negative_prompt) != len(prompt): + raise ValueError( + f"`negative_prompt` must have same length as `prompt`. Got {len(negative_prompt)} vs {len(prompt)}." + ) + + if negative_prompt_embeds_qwen is None: + negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_prompt_cu_seqlens = ( + self.encode_prompt( + prompt=negative_prompt, + image=image, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables with image conditioning + num_channels_latents = self.transformer.config.in_visual_dim + latents = self.prepare_latents( + image=image, + batch_size=batch_size * num_images_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + dtype=dtype, + device=device, + generator=generator, + latents=latents, + ) + + # 6. Prepare rope positions for positional encoding + visual_rope_pos = [ + torch.arange(1, device=device), + torch.arange(height // self.vae_scale_factor_spatial // 2, device=device), + torch.arange(width // self.vae_scale_factor_spatial // 2, device=device), + ] + + text_rope_pos = torch.arange(prompt_cu_seqlens.diff().max().item(), device=device) + + negative_text_rope_pos = ( + torch.arange(negative_prompt_cu_seqlens.diff().max().item(), device=device) + if negative_prompt_cu_seqlens is not None + else None + ) + + # 7. Calculate dynamic scale factor based on resolution + scale_factor = [1.0, 1.0, 1.0] + + # 8. Sparse Params for efficient attention + sparse_params = None + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + timestep = t.unsqueeze(0).repeat(batch_size * num_images_per_prompt) + + # Predict noise residual + pred_velocity = self.transformer( + hidden_states=latents.to(dtype), + encoder_hidden_states=prompt_embeds_qwen.to(dtype), + pooled_projections=prompt_embeds_clip.to(dtype), + timestep=timestep.to(dtype), + visual_rope_pos=visual_rope_pos, + text_rope_pos=text_rope_pos, + scale_factor=scale_factor, + sparse_params=sparse_params, + return_dict=True, + ).sample + + if self.guidance_scale > 1.0 and negative_prompt_embeds_qwen is not None: + uncond_pred_velocity = self.transformer( + hidden_states=latents.to(dtype), + encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype), + pooled_projections=negative_prompt_embeds_clip.to(dtype), + timestep=timestep.to(dtype), + visual_rope_pos=visual_rope_pos, + text_rope_pos=negative_text_rope_pos, + scale_factor=scale_factor, + sparse_params=sparse_params, + return_dict=True, + ).sample + + pred_velocity = uncond_pred_velocity + guidance_scale * (pred_velocity - uncond_pred_velocity) + + latents[:, :, :, :, :num_channels_latents] = self.scheduler.step( + pred_velocity[:, :], t, latents[:, :, :, :, :num_channels_latents], return_dict=False + )[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds_qwen = callback_outputs.pop("prompt_embeds_qwen", prompt_embeds_qwen) + prompt_embeds_clip = callback_outputs.pop("prompt_embeds_clip", prompt_embeds_clip) + negative_prompt_embeds_qwen = callback_outputs.pop( + "negative_prompt_embeds_qwen", negative_prompt_embeds_qwen + ) + negative_prompt_embeds_clip = callback_outputs.pop( + "negative_prompt_embeds_clip", negative_prompt_embeds_clip + ) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # 9. Post-processing - extract main latents + latents = latents[:, :, :, :, :num_channels_latents] + + # 10. Decode latents to image + if output_type != "latent": + latents = latents.to(self.vae.dtype) + # Reshape and normalize latents + latents = latents.reshape( + batch_size, + num_images_per_prompt, + 1, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + num_channels_latents, + ) + latents = latents.permute(0, 1, 5, 2, 3, 4) # [batch, num_images, channels, 1, height, width] + latents = latents.reshape( + batch_size * num_images_per_prompt, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + # Normalize and decode through VAE + latents = latents / self.vae.config.scaling_factor + image = self.vae.decode(latents).sample + image = self.image_processor.postprocess(image, output_type=output_type) + else: + image = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return KandinskyImagePipelineOutput(image=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/pipeline_kandinsky_i2v.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/pipeline_kandinsky_i2v.py new file mode 100644 index 0000000000000000000000000000000000000000..ad4bb182d248dc2d97ce4d8e605c7bb116f07cb2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/pipeline_kandinsky_i2v.py @@ -0,0 +1,1052 @@ +# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +from typing import Callable + +import regex as re +import torch +from torch.nn import functional as F +from transformers import CLIPTextModel, CLIPTokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import KandinskyLoraLoaderMixin +from ...models import AutoencoderKLHunyuanVideo +from ...models.transformers import Kandinsky5Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler + +# Add imports for offloading and tiling +from ...utils import ( + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import KandinskyPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) + +EXAMPLE_DOC_STRING = """ + Examples: + + ```python + >>> import torch + >>> from diffusers import Kandinsky5I2VPipeline + >>> from diffusers.utils import export_to_video, load_image + + >>> # Available models: + >>> # kandinskylab/Kandinsky-5.0-I2V-Pro-sft-5s-Diffusers + + >>> model_id = "kandinskylab/Kandinsky-5.0-I2V-Pro-sft-5s-Diffusers" + >>> pipe = Kandinsky5I2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) + >>> pipe = pipe.to("cuda") + + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" + ... ) + >>> prompt = "An astronaut floating in space with Earth in the background, cinematic shot" + >>> negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards" + + >>> output = pipe( + ... image=image, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=512, + ... width=768, + ... num_frames=121, + ... num_inference_steps=50, + ... guidance_scale=5.0, + ... ).frames[0] + + >>> export_to_video(output, "output.mp4", fps=24, quality=9) + ``` +""" + + +def basic_clean(text): + """ + Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py + + Clean text using ftfy if available and unescape HTML entities. + """ + if is_ftfy_available(): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + """ + Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py + + Normalize whitespace in text by replacing multiple spaces with single space. + """ + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + """ + Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py + + Apply both basic cleaning and whitespace normalization to prompts. + """ + text = whitespace_clean(basic_clean(text)) + return text + + +class Kandinsky5I2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): + r""" + Pipeline for image-to-video generation using Kandinsky 5.0. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + transformer ([`Kandinsky5Transformer3DModel`]): + Conditional Transformer to denoise the encoded video latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder Model [hunyuanvideo-community/HunyuanVideo + (vae)](https://huggingface.co/hunyuanvideo-community/HunyuanVideo) to encode and decode videos to and from + latent representations. + text_encoder ([`Qwen2_5_VLForConditionalGeneration`]): + Frozen text-encoder [Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct). + tokenizer ([`AutoProcessor`]): + Tokenizer for Qwen2.5-VL. + text_encoder_2 ([`CLIPTextModel`]): + Frozen [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer_2 ([`CLIPTokenizer`]): + Tokenizer for CLIP. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = [ + "latents", + "prompt_embeds_qwen", + "prompt_embeds_clip", + "negative_prompt_embeds_qwen", + "negative_prompt_embeds_clip", + ] + + def __init__( + self, + transformer: Kandinsky5Transformer3DModel, + vae: AutoencoderKLHunyuanVideo, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2VLProcessor, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + transformer=transformer, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + scheduler=scheduler, + ) + + self.prompt_template = "\n".join( + [ + "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.", + "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.", + "Describe the location of the video, main characters or objects and their action.", + "Describe the dynamism of the video and presented actions.", + "Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.", + "Describe the visual effects, postprocessing and transitions if they are presented in the video.", + "Pay attention to the order of key actions shown in the scene.<|im_end|>", + "<|im_start|>user\n{}<|im_end|>", + ] + ) + self.prompt_template_encode_start_idx = 129 + + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4 + ) + self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_scale_factor(self, height: int, width: int) -> tuple: + """ + Calculate the scale factor based on resolution. + + Args: + height (int): Video height + width (int): Video width + + Returns: + tuple: Scale factor as (temporal_scale, height_scale, width_scale) + """ + + def between_480p(x): + return 480 <= x <= 854 + + if between_480p(height) and between_480p(width): + return (1, 2, 2) + else: + return (1, 3.16, 3.16) + + @staticmethod + def fast_sta_nabla(T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda") -> torch.Tensor: + """ + Create a sparse temporal attention (STA) mask for efficient video generation. + + This method generates a mask that limits attention to nearby frames and spatial positions, reducing + computational complexity for video generation. + + Args: + T (int): Number of temporal frames + H (int): Height in latent space + W (int): Width in latent space + wT (int): Temporal attention window size + wH (int): Height attention window size + wW (int): Width attention window size + device (str): Device to create tensor on + + Returns: + torch.Tensor: Sparse attention mask of shape (T*H*W, T*H*W) + """ + l = torch.Tensor([T, H, W]).amax() + r = torch.arange(0, l, 1, dtype=torch.int16, device=device) + mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs() + sta_t, sta_h, sta_w = ( + mat[:T, :T].flatten(), + mat[:H, :H].flatten(), + mat[:W, :W].flatten(), + ) + sta_t = sta_t <= wT // 2 + sta_h = sta_h <= wH // 2 + sta_w = sta_w <= wW // 2 + sta_hw = (sta_h.unsqueeze(1) * sta_w.unsqueeze(0)).reshape(H, H, W, W).transpose(1, 2).flatten() + sta = (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0)).reshape(T, T, H * W, H * W).transpose(1, 2) + return sta.reshape(T * H * W, T * H * W) + + def get_sparse_params(self, sample, device): + """ + Generate sparse attention parameters for the transformer based on sample dimensions. + + This method computes the sparse attention configuration needed for efficient video processing in the + transformer model. + + Args: + sample (torch.Tensor): Input sample tensor + device (torch.device): Device to place tensors on + + Returns: + Dict: Dictionary containing sparse attention parameters + """ + assert self.transformer.config.patch_size[0] == 1 + B, T, H, W, _ = sample.shape + T, H, W = ( + T // self.transformer.config.patch_size[0], + H // self.transformer.config.patch_size[1], + W // self.transformer.config.patch_size[2], + ) + if self.transformer.config.attention_type == "nabla": + sta_mask = self.fast_sta_nabla( + T, + H // 8, + W // 8, + self.transformer.config.attention_wT, + self.transformer.config.attention_wH, + self.transformer.config.attention_wW, + device=device, + ) + + sparse_params = { + "sta_mask": sta_mask.unsqueeze_(0).unsqueeze_(0), + "attention_type": self.transformer.config.attention_type, + "to_fractal": True, + "P": self.transformer.config.attention_P, + "wT": self.transformer.config.attention_wT, + "wW": self.transformer.config.attention_wW, + "wH": self.transformer.config.attention_wH, + "add_sta": self.transformer.config.attention_add_sta, + "visual_shape": (T, H, W), + "method": self.transformer.config.attention_method, + } + else: + sparse_params = None + + return sparse_params + + def _encode_prompt_qwen( + self, + prompt: str | list[str], + device: torch.device | None = None, + max_sequence_length: int = 256, + dtype: torch.dtype | None = None, + ): + """ + Encode prompt using Qwen2.5-VL text encoder. + + This method processes the input prompt through the Qwen2.5-VL model to generate text embeddings suitable for + video generation. + + Args: + prompt (str | list[str]): Input prompt or list of prompts + device (torch.device): Device to run encoding on + max_sequence_length (int): Maximum sequence length for tokenization + dtype (torch.dtype): Data type for embeddings + + Returns: + tuple[torch.Tensor, torch.Tensor]: Text embeddings and cumulative sequence lengths + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + full_texts = [self.prompt_template.format(p) for p in prompt] + max_allowed_len = self.prompt_template_encode_start_idx + max_sequence_length + + untruncated_ids = self.tokenizer( + text=full_texts, + images=None, + videos=None, + return_tensors="pt", + padding="longest", + )["input_ids"] + + if untruncated_ids.shape[-1] > max_allowed_len: + for i, text in enumerate(full_texts): + tokens = untruncated_ids[i][self.prompt_template_encode_start_idx : -2] + removed_text = self.tokenizer.decode(tokens[max_sequence_length - 2 :]) + if len(removed_text) > 0: + full_texts[i] = text[: -len(removed_text)] + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + inputs = self.tokenizer( + text=full_texts, + images=None, + videos=None, + max_length=max_allowed_len, + truncation=True, + return_tensors="pt", + padding=True, + ).to(device) + + embeds = self.text_encoder( + input_ids=inputs["input_ids"], + return_dict=True, + output_hidden_states=True, + )["hidden_states"][-1][:, self.prompt_template_encode_start_idx :] + + attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx :] + cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32) + + return embeds.to(dtype), cu_seqlens + + def _encode_prompt_clip( + self, + prompt: str | list[str], + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + """ + Encode prompt using CLIP text encoder. + + This method processes the input prompt through the CLIP model to generate pooled embeddings that capture + semantic information. + + Args: + prompt (str | list[str]): Input prompt or list of prompts + device (torch.device): Device to run encoding on + dtype (torch.dtype): Data type for embeddings + + Returns: + torch.Tensor: Pooled text embeddings from CLIP + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype + + inputs = self.tokenizer_2( + prompt, + max_length=77, + truncation=True, + add_special_tokens=True, + padding="max_length", + return_tensors="pt", + ).to(device) + + pooled_embed = self.text_encoder_2(**inputs)["pooler_output"] + + return pooled_embed.to(dtype) + + @staticmethod + def adaptive_mean_std_normalization(source, reference): + source_mean = source.mean(dim=(1, 2, 3, 4), keepdim=True) + source_std = source.std(dim=(1, 2, 3, 4), keepdim=True) + # magic constants - limit changes in latents + clump_mean_low = 0.05 + clump_mean_high = 0.1 + clump_std_low = 0.1 + clump_std_high = 0.25 + + reference_mean = torch.clamp(reference.mean(), source_mean - clump_mean_low, source_mean + clump_mean_high) + reference_std = torch.clamp(reference.std(), source_std - clump_std_low, source_std + clump_std_high) + + # normalization + normalized = (source - source_mean) / source_std + normalized = normalized * reference_std + reference_mean + + return normalized + + def normalize_first_frame(self, latents, reference_frames=5, clump_values=False): + latents_copy = latents.clone() + samples = latents_copy + + if samples.shape[1] <= 1: + return (latents, "Only one frame, no normalization needed") + + nFr = 4 + first_frames = samples.clone()[:, :nFr] + reference_frames_data = samples[:, nFr : nFr + min(reference_frames, samples.shape[1] - 1)] + + normalized_first = self.adaptive_mean_std_normalization(first_frames, reference_frames_data) + if clump_values: + min_val = reference_frames_data.min() + max_val = reference_frames_data.max() + normalized_first = torch.clamp(normalized_first, min_val, max_val) + + samples[:, :nFr] = normalized_first + + return samples + + def encode_prompt( + self, + prompt: str | list[str], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes a single prompt (positive or negative) into text encoder hidden states. + + This method combines embeddings from both Qwen2.5-VL and CLIP text encoders to create comprehensive text + representations for video generation. + + Args: + prompt (`str` or `list[str]`): + Prompt to be encoded. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos to generate per prompt. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length for text encoding. + device (`torch.device`, *optional*): + Torch device. + dtype (`torch.dtype`, *optional*): + Torch dtype. + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + - Qwen text embeddings of shape (batch_size * num_videos_per_prompt, sequence_length, embedding_dim) + - CLIP pooled embeddings of shape (batch_size * num_videos_per_prompt, clip_embedding_dim) + - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings of shape (batch_size * + num_videos_per_prompt + 1,) + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + if not isinstance(prompt, list): + prompt = [prompt] + + batch_size = len(prompt) + + prompt = [prompt_clean(p) for p in prompt] + + # Encode with Qwen2.5-VL + prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen( + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length, + dtype=dtype, + ) + # prompt_embeds_qwen shape: [batch_size, seq_len, embed_dim] + + # Encode with CLIP + prompt_embeds_clip = self._encode_prompt_clip( + prompt=prompt, + device=device, + dtype=dtype, + ) + # prompt_embeds_clip shape: [batch_size, clip_embed_dim] + + # Repeat embeddings for num_videos_per_prompt + # Qwen embeddings: repeat sequence for each video, then reshape + prompt_embeds_qwen = prompt_embeds_qwen.repeat( + 1, num_videos_per_prompt, 1 + ) # [batch_size, seq_len * num_videos_per_prompt, embed_dim] + # Reshape to [batch_size * num_videos_per_prompt, seq_len, embed_dim] + prompt_embeds_qwen = prompt_embeds_qwen.view( + batch_size * num_videos_per_prompt, -1, prompt_embeds_qwen.shape[-1] + ) + + # CLIP embeddings: repeat for each video + prompt_embeds_clip = prompt_embeds_clip.repeat( + 1, num_videos_per_prompt, 1 + ) # [batch_size, num_videos_per_prompt, clip_embed_dim] + # Reshape to [batch_size * num_videos_per_prompt, clip_embed_dim] + prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_videos_per_prompt, -1) + + # Repeat cumulative sequence lengths for num_videos_per_prompt + # Original differences (lengths) for each prompt in the batch + original_lengths = prompt_cu_seqlens.diff() # [len1, len2, ...] + # Repeat the lengths for num_videos_per_prompt + repeated_lengths = original_lengths.repeat_interleave( + num_videos_per_prompt + ) # [len1, len1, ..., len2, len2, ...] + # Reconstruct the cumulative lengths + repeated_cu_seqlens = torch.cat( + [torch.tensor([0], device=device, dtype=torch.int32), repeated_lengths.cumsum(0)] + ) + + return prompt_embeds_qwen, prompt_embeds_clip, repeated_cu_seqlens + + def check_inputs( + self, + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds_qwen=None, + prompt_embeds_clip=None, + negative_prompt_embeds_qwen=None, + negative_prompt_embeds_clip=None, + prompt_cu_seqlens=None, + negative_prompt_cu_seqlens=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + """ + Validate input parameters for the pipeline. + + Args: + prompt: Input prompt + negative_prompt: Negative prompt for guidance + image: Input image for conditioning + height: Video height + width: Video width + prompt_embeds_qwen: Pre-computed Qwen prompt embeddings + prompt_embeds_clip: Pre-computed CLIP prompt embeddings + negative_prompt_embeds_qwen: Pre-computed Qwen negative prompt embeddings + negative_prompt_embeds_clip: Pre-computed CLIP negative prompt embeddings + prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen positive prompt + negative_prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen negative prompt + callback_on_step_end_tensor_inputs: Callback tensor inputs + + Raises: + ValueError: If inputs are invalid + """ + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError("max_sequence_length must be less than 1024") + + if image is None: + raise ValueError("`image` must be provided for image-to-video generation") + + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + # Check for consistency within positive prompt embeddings and sequence lengths + if prompt_embeds_qwen is not None or prompt_embeds_clip is not None or prompt_cu_seqlens is not None: + if prompt_embeds_qwen is None or prompt_embeds_clip is None or prompt_cu_seqlens is None: + raise ValueError( + "If any of `prompt_embeds_qwen`, `prompt_embeds_clip`, or `prompt_cu_seqlens` is provided, " + "all three must be provided." + ) + + # Check for consistency within negative prompt embeddings and sequence lengths + if ( + negative_prompt_embeds_qwen is not None + or negative_prompt_embeds_clip is not None + or negative_prompt_cu_seqlens is not None + ): + if ( + negative_prompt_embeds_qwen is None + or negative_prompt_embeds_clip is None + or negative_prompt_cu_seqlens is None + ): + raise ValueError( + "If any of `negative_prompt_embeds_qwen`, `negative_prompt_embeds_clip`, or `negative_prompt_cu_seqlens` is provided, " + "all three must be provided." + ) + + # Check if prompt or embeddings are provided (either prompt or all required embedding components for positive) + if prompt is None and prompt_embeds_qwen is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_qwen` (and corresponding `prompt_embeds_clip` and `prompt_cu_seqlens`). Cannot leave all undefined." + ) + + # Validate types for prompt and negative_prompt if provided + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + def prepare_latents( + self, + image: PipelineImageInput, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Prepare initial latent variables for image-to-video generation. + + This method creates random noise latents for all frames except the first frame, which is replaced with the + encoded input image. + + Args: + image (PipelineImageInput): Input image to condition the generation on + batch_size (int): Number of videos to generate + num_channels_latents (int): Number of channels in latent space + height (int): Height of generated video + width (int): Width of generated video + num_frames (int): Number of frames in video + dtype (torch.dtype): Data type for latents + device (torch.device): Device to create latents on + generator (torch.Generator): Random number generator + latents (torch.Tensor): Pre-existing latents to use + + Returns: + torch.Tensor: Prepared latent tensor with first frame as encoded image + """ + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + num_channels_latents, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # Generate random noise for all frames + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # Encode the input image to use as first frame + # Preprocess image + image_tensor = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=dtype) + + # Encode image to latents using VAE + with torch.no_grad(): + # Convert image to video format [batch, channels, 1, height, width] + image_video = image_tensor.unsqueeze(2) # Add temporal dimension + image_latents = self.vae.encode(image_video).latent_dist.sample(generator=generator) + + # Normalize latents if needed + if hasattr(self.vae.config, "scaling_factor"): + image_latents = image_latents * self.vae.config.scaling_factor + + # Reshape to match latent dimensions [batch, frames, height, width, channels] + image_latents = image_latents.permute(0, 2, 3, 4, 1) # [batch, 1, H, W, C] + + # Replace first frame with encoded image + latents[:, 0:1] = image_latents + + if self.transformer.visual_cond: + # For visual conditioning, concatenate with zeros and mask + visual_cond = torch.zeros_like(latents) + visual_cond_mask = torch.zeros( + [ + batch_size, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + 1, + ], + dtype=latents.dtype, + device=latents.device, + ) + + visual_cond_mask[:, 0:1] = 1 + visual_cond[:, 0:1] = image_latents + + latents = torch.cat([latents, visual_cond, visual_cond_mask], dim=-1) + + return latents + + @property + def guidance_scale(self): + """Get the current guidance scale value.""" + return self._guidance_scale + + @property + def num_timesteps(self): + """Get the number of denoising timesteps.""" + return self._num_timesteps + + @property + def interrupt(self): + """Check if generation has been interrupted.""" + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds_qwen: torch.Tensor | None = None, + prompt_embeds_clip: torch.Tensor | None = None, + negative_prompt_embeds_qwen: torch.Tensor | None = None, + negative_prompt_embeds_clip: torch.Tensor | None = None, + prompt_cu_seqlens: torch.Tensor | None = None, + negative_prompt_cu_seqlens: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int, None], PipelineCallback | MultiPipelineCallbacks] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for image-to-video generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to avoid during video generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). + height (`int`, defaults to `512`): + The height in pixels of the generated video. + width (`int`, defaults to `768`): + The width in pixels of the generated video. + num_frames (`int`, defaults to `121`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in classifier-free guidance. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A torch generator to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents. + prompt_embeds_qwen (`torch.Tensor`, *optional*): + Pre-generated Qwen text embeddings. + prompt_embeds_clip (`torch.Tensor`, *optional*): + Pre-generated CLIP text embeddings. + negative_prompt_embeds_qwen (`torch.Tensor`, *optional*): + Pre-generated Qwen negative text embeddings. + negative_prompt_embeds_clip (`torch.Tensor`, *optional*): + Pre-generated CLIP negative text embeddings. + prompt_cu_seqlens (`torch.Tensor`, *optional*): + Pre-generated cumulative sequence lengths for Qwen positive prompt. + negative_prompt_cu_seqlens (`torch.Tensor`, *optional*): + Pre-generated cumulative sequence lengths for Qwen negative prompt. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`KandinskyPipelineOutput`]. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function that is called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length for text encoding. + + Examples: + + Returns: + [`~KandinskyPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`KandinskyPipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated videos. + """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + negative_prompt=negative_prompt, + image=image, + height=height, + width=width, + prompt_embeds_qwen=prompt_embeds_qwen, + prompt_embeds_clip=prompt_embeds_clip, + negative_prompt_embeds_qwen=negative_prompt_embeds_qwen, + negative_prompt_embeds_clip=negative_prompt_embeds_clip, + prompt_cu_seqlens=prompt_cu_seqlens, + negative_prompt_cu_seqlens=negative_prompt_cu_seqlens, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._interrupt = False + + device = self._execution_device + dtype = self.transformer.dtype + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds_qwen.shape[0] + + # 3. Encode input prompt + if prompt_embeds_qwen is None: + prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = self.encode_prompt( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if self.guidance_scale > 1.0: + if negative_prompt is None: + negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards" + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * len(prompt) if prompt is not None else [negative_prompt] + elif len(negative_prompt) != len(prompt): + raise ValueError( + f"`negative_prompt` must have same length as `prompt`. Got {len(negative_prompt)} vs {len(prompt)}." + ) + + if negative_prompt_embeds_qwen is None: + negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_prompt_cu_seqlens = ( + self.encode_prompt( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables with image conditioning + num_channels_latents = self.transformer.config.in_visual_dim + latents = self.prepare_latents( + image=image, + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames=num_frames, + dtype=dtype, + device=device, + generator=generator, + latents=latents, + ) + + # 6. Prepare rope positions for positional encoding + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + visual_rope_pos = [ + torch.arange(num_latent_frames, device=device), + torch.arange(height // self.vae_scale_factor_spatial // 2, device=device), + torch.arange(width // self.vae_scale_factor_spatial // 2, device=device), + ] + + text_rope_pos = torch.arange(prompt_cu_seqlens.diff().max().item(), device=device) + + negative_text_rope_pos = ( + torch.arange(negative_prompt_cu_seqlens.diff().max().item(), device=device) + if negative_prompt_cu_seqlens is not None + else None + ) + + # 7. Calculate dynamic scale factor based on resolution + scale_factor = self._get_scale_factor(height, width) + + # 8. Sparse Params for efficient attention + sparse_params = self.get_sparse_params(latents, device) + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + timestep = t.unsqueeze(0).repeat(batch_size * num_videos_per_prompt) + + # Predict noise residual + pred_velocity = self.transformer( + hidden_states=latents.to(dtype), + encoder_hidden_states=prompt_embeds_qwen.to(dtype), + pooled_projections=prompt_embeds_clip.to(dtype), + timestep=timestep.to(dtype), + visual_rope_pos=visual_rope_pos, + text_rope_pos=text_rope_pos, + scale_factor=scale_factor, + sparse_params=sparse_params, + return_dict=True, + ).sample + + if self.guidance_scale > 1.0 and negative_prompt_embeds_qwen is not None: + uncond_pred_velocity = self.transformer( + hidden_states=latents.to(dtype), + encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype), + pooled_projections=negative_prompt_embeds_clip.to(dtype), + timestep=timestep.to(dtype), + visual_rope_pos=visual_rope_pos, + text_rope_pos=negative_text_rope_pos, + scale_factor=scale_factor, + sparse_params=sparse_params, + return_dict=True, + ).sample + + pred_velocity = uncond_pred_velocity + guidance_scale * (pred_velocity - uncond_pred_velocity) + + latents[:, 1:, :, :, :num_channels_latents] = self.scheduler.step( + pred_velocity[:, 1:], t, latents[:, 1:, :, :, :num_channels_latents], return_dict=False + )[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds_qwen = callback_outputs.pop("prompt_embeds_qwen", prompt_embeds_qwen) + prompt_embeds_clip = callback_outputs.pop("prompt_embeds_clip", prompt_embeds_clip) + negative_prompt_embeds_qwen = callback_outputs.pop( + "negative_prompt_embeds_qwen", negative_prompt_embeds_qwen + ) + negative_prompt_embeds_clip = callback_outputs.pop( + "negative_prompt_embeds_clip", negative_prompt_embeds_clip + ) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # 9. Post-processing - extract main latents + latents = latents[:, :, :, :, :num_channels_latents] + + # 10. fix mesh artifacts + latents = self.normalize_first_frame(latents) + + # 11. Decode latents to video + if output_type != "latent": + latents = latents.to(self.vae.dtype) + # Reshape and normalize latents + video = latents.reshape( + batch_size, + num_videos_per_prompt, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + num_channels_latents, + ) + video = video.permute(0, 1, 5, 2, 3, 4) # [batch, num_videos, channels, frames, height, width] + video = video.reshape( + batch_size * num_videos_per_prompt, + num_channels_latents, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + # Normalize and decode through VAE + video = video / self.vae.config.scaling_factor + video = self.vae.decode(video).sample + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return KandinskyPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/pipeline_kandinsky_t2i.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/pipeline_kandinsky_t2i.py new file mode 100644 index 0000000000000000000000000000000000000000..2a58d4bed33a42fe71c77f134bd28d73155a96ed --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/pipeline_kandinsky_t2i.py @@ -0,0 +1,816 @@ +# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +from typing import Callable + +import numpy as np +import regex as re +import torch +from torch.nn import functional as F +from transformers import CLIPTextModel, CLIPTokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...loaders import KandinskyLoraLoaderMixin +from ...models import AutoencoderKL +from ...models.transformers import Kandinsky5Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler + +# Add imports for offloading and tiling +from ...utils import ( + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import KandinskyImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) + +EXAMPLE_DOC_STRING = """ + Examples: + + ```python + >>> import torch + >>> from diffusers import Kandinsky5T2IPipeline + + >>> # Available models: + >>> # kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers + >>> # kandinskylab/Kandinsky-5.0-T2I-Lite-pretrain-Diffusers + + >>> model_id = "kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers" + >>> pipe = Kandinsky5T2IPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen." + + >>> output = pipe( + ... prompt=prompt, + ... negative_prompt="", + ... height=1024, + ... width=1024, + ... num_inference_steps=50, + ... guidance_scale=3.5, + ... ).frames[0] + ``` +""" + + +def basic_clean(text): + """ + Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py + + Clean text using ftfy if available and unescape HTML entities. + """ + if is_ftfy_available(): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + """ + Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py + + Normalize whitespace in text by replacing multiple spaces with single space. + """ + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + """ + Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py + + Apply both basic cleaning and whitespace normalization to prompts. + """ + text = whitespace_clean(basic_clean(text)) + return text + + +class Kandinsky5T2IPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): + r""" + Pipeline for text-to-image generation using Kandinsky 5.0. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + transformer ([`Kandinsky5Transformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder Model [black-forest-labs/FLUX.1-dev + (vae)](https://huggingface.co/black-forest-labs/FLUX.1-dev) to encode and decode videos to and from latent + representations. + text_encoder ([`Qwen2_5_VLForConditionalGeneration`]): + Frozen text-encoder [Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct). + tokenizer ([`AutoProcessor`]): + Tokenizer for Qwen2.5-VL. + text_encoder_2 ([`CLIPTextModel`]): + Frozen [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer_2 ([`CLIPTokenizer`]): + Tokenizer for CLIP. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = [ + "latents", + "prompt_embeds_qwen", + "prompt_embeds_clip", + "negative_prompt_embeds_qwen", + "negative_prompt_embeds_clip", + ] + + def __init__( + self, + transformer: Kandinsky5Transformer3DModel, + vae: AutoencoderKL, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2VLProcessor, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + transformer=transformer, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + scheduler=scheduler, + ) + + self.prompt_template = "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>" + self.prompt_template_encode_start_idx = 41 + + self.vae_scale_factor_spatial = 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.resolutions = [(1024, 1024), (640, 1408), (1408, 640), (768, 1280), (1280, 768), (896, 1152), (1152, 896)] + + def _encode_prompt_qwen( + self, + prompt: list[str], + device: torch.device | None = None, + max_sequence_length: int = 512, + dtype: torch.dtype | None = None, + ): + """ + Encode prompt using Qwen2.5-VL text encoder. + + This method processes the input prompt through the Qwen2.5-VL model to generate text embeddings suitable for + image generation. + + Args: + prompt list[str]: Input list of prompts + device (torch.device): Device to run encoding on + max_sequence_length (int): Maximum sequence length for tokenization + dtype (torch.dtype): Data type for embeddings + + Returns: + tuple[torch.Tensor, torch.Tensor]: Text embeddings and cumulative sequence lengths + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + full_texts = [self.prompt_template.format(p) for p in prompt] + max_allowed_len = self.prompt_template_encode_start_idx + max_sequence_length + + untruncated_ids = self.tokenizer( + text=full_texts, + images=None, + videos=None, + return_tensors="pt", + padding="longest", + )["input_ids"] + + if untruncated_ids.shape[-1] > max_allowed_len: + for i, text in enumerate(full_texts): + tokens = untruncated_ids[i][self.prompt_template_encode_start_idx : -2] + removed_text = self.tokenizer.decode(tokens[max_sequence_length - 2 :]) + if len(removed_text) > 0: + full_texts[i] = text[: -len(removed_text)] + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + inputs = self.tokenizer( + text=full_texts, + images=None, + videos=None, + max_length=max_allowed_len, + truncation=True, + return_tensors="pt", + padding=True, + ).to(device) + + embeds = self.text_encoder( + input_ids=inputs["input_ids"], + return_dict=True, + output_hidden_states=True, + )["hidden_states"][-1][:, self.prompt_template_encode_start_idx :] + attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx :] + cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32) + + return embeds.to(dtype), cu_seqlens + + def _encode_prompt_clip( + self, + prompt: str | list[str], + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + """ + Encode prompt using CLIP text encoder. + + This method processes the input prompt through the CLIP model to generate pooled embeddings that capture + semantic information. + + Args: + prompt (str | list[str]): Input prompt or list of prompts + device (torch.device): Device to run encoding on + dtype (torch.dtype): Data type for embeddings + + Returns: + torch.Tensor: Pooled text embeddings from CLIP + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype + + inputs = self.tokenizer_2( + prompt, + max_length=77, + truncation=True, + add_special_tokens=True, + padding="max_length", + return_tensors="pt", + ).to(device) + + pooled_embed = self.text_encoder_2(**inputs)["pooler_output"] + + return pooled_embed.to(dtype) + + def encode_prompt( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes a single prompt (positive or negative) into text encoder hidden states. + + This method combines embeddings from both Qwen2.5-VL and CLIP text encoders to create comprehensive text + representations for image generation. + + Args: + prompt (`str` or `list[str]`): + Prompt to be encoded. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images to generate per prompt. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length for text encoding. Must be less than 1024 + device (`torch.device`, *optional*): + Torch device. + dtype (`torch.dtype`, *optional*): + Torch dtype. + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + - Qwen text embeddings of shape (batch_size * num_images_per_prompt, sequence_length, embedding_dim) + - CLIP pooled embeddings of shape (batch_size * num_images_per_prompt, clip_embedding_dim) + - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings of shape (batch_size * + num_images_per_prompt + 1,) + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + if not isinstance(prompt, list): + prompt = [prompt] + + batch_size = len(prompt) + + prompt = [prompt_clean(p) for p in prompt] + + # Encode with Qwen2.5-VL + prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen( + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length, + dtype=dtype, + ) + # prompt_embeds_qwen shape: [batch_size, seq_len, embed_dim] + + # Encode with CLIP + prompt_embeds_clip = self._encode_prompt_clip( + prompt=prompt, + device=device, + dtype=dtype, + ) + # prompt_embeds_clip shape: [batch_size, clip_embed_dim] + + # Repeat embeddings for num_images_per_prompt + # Qwen embeddings: repeat sequence for each image, then reshape + prompt_embeds_qwen = prompt_embeds_qwen.repeat( + 1, num_images_per_prompt, 1 + ) # [batch_size, seq_len * num_images_per_prompt, embed_dim] + # Reshape to [batch_size * num_images_per_prompt, seq_len, embed_dim] + prompt_embeds_qwen = prompt_embeds_qwen.view( + batch_size * num_images_per_prompt, -1, prompt_embeds_qwen.shape[-1] + ) + + # CLIP embeddings: repeat for each image + prompt_embeds_clip = prompt_embeds_clip.repeat( + 1, num_images_per_prompt, 1 + ) # [batch_size, num_images_per_prompt, clip_embed_dim] + # Reshape to [batch_size * num_images_per_prompt, clip_embed_dim] + prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_images_per_prompt, -1) + + # Repeat cumulative sequence lengths for num_images_per_prompt + # Original differences (lengths) for each prompt in the batch + original_lengths = prompt_cu_seqlens.diff() # [len1, len2, ...] + # Repeat the lengths for num_images_per_prompt + repeated_lengths = original_lengths.repeat_interleave( + num_images_per_prompt + ) # [len1, len1, ..., len2, len2, ...] + # Reconstruct the cumulative lengths + repeated_cu_seqlens = torch.cat( + [torch.tensor([0], device=device, dtype=torch.int32), repeated_lengths.cumsum(0)] + ) + + return prompt_embeds_qwen, prompt_embeds_clip, repeated_cu_seqlens + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds_qwen=None, + prompt_embeds_clip=None, + negative_prompt_embeds_qwen=None, + negative_prompt_embeds_clip=None, + prompt_cu_seqlens=None, + negative_prompt_cu_seqlens=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + """ + Validate input parameters for the pipeline. + + Args: + prompt: Input prompt + negative_prompt: Negative prompt for guidance + height: Image height + width: Image width + prompt_embeds_qwen: Pre-computed Qwen prompt embeddings + prompt_embeds_clip: Pre-computed CLIP prompt embeddings + negative_prompt_embeds_qwen: Pre-computed Qwen negative prompt embeddings + negative_prompt_embeds_clip: Pre-computed CLIP negative prompt embeddings + prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen positive prompt + negative_prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen negative prompt + callback_on_step_end_tensor_inputs: Callback tensor inputs + + Raises: + ValueError: If inputs are invalid + """ + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError("max_sequence_length must be less than 1024") + + if (width, height) not in self.resolutions: + resolutions_str = ",".join([f"({w},{h})" for w, h in self.resolutions]) + logger.warning( + f"`height` and `width` have to be one of {resolutions_str}, but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + # Check for consistency within positive prompt embeddings and sequence lengths + if prompt_embeds_qwen is not None or prompt_embeds_clip is not None or prompt_cu_seqlens is not None: + if prompt_embeds_qwen is None or prompt_embeds_clip is None or prompt_cu_seqlens is None: + raise ValueError( + "If any of `prompt_embeds_qwen`, `prompt_embeds_clip`, or `prompt_cu_seqlens` is provided, " + "all three must be provided." + ) + + # Check for consistency within negative prompt embeddings and sequence lengths + if ( + negative_prompt_embeds_qwen is not None + or negative_prompt_embeds_clip is not None + or negative_prompt_cu_seqlens is not None + ): + if ( + negative_prompt_embeds_qwen is None + or negative_prompt_embeds_clip is None + or negative_prompt_cu_seqlens is None + ): + raise ValueError( + "If any of `negative_prompt_embeds_qwen`, `negative_prompt_embeds_clip`, or `negative_prompt_cu_seqlens` is provided, " + "all three must be provided." + ) + + # Check if prompt or embeddings are provided (either prompt or all required embedding components for positive) + if prompt is None and prompt_embeds_qwen is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_qwen` (and corresponding `prompt_embeds_clip` and `prompt_cu_seqlens`). Cannot leave all undefined." + ) + + # Validate types for prompt and negative_prompt if provided + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 1024, + width: int = 1024, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Prepare initial latent variables for text-to-image generation. + + This method creates random noise latents + + Args: + batch_size (int): Number of images to generate + num_channels_latents (int): Number of channels in latent space + height (int): Height of generated image + width (int): Width of generated image + dtype (torch.dtype): Data type for latents + device (torch.device): Device to create latents on + generator (torch.Generator): Random number generator + latents (torch.Tensor): Pre-existing latents to use + + Returns: + torch.Tensor: Prepared latent tensor + """ + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + 1, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + num_channels_latents, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # Generate random noise + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + """Get the current guidance scale value.""" + return self._guidance_scale + + @property + def num_timesteps(self): + """Get the number of denoising timesteps.""" + return self._num_timesteps + + @property + def interrupt(self): + """Check if generation has been interrupted.""" + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + height: int = 1024, + width: int = 1024, + num_inference_steps: int = 50, + guidance_scale: float = 3.5, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds_qwen: torch.Tensor | None = None, + prompt_embeds_clip: torch.Tensor | None = None, + negative_prompt_embeds_qwen: torch.Tensor | None = None, + negative_prompt_embeds_clip: torch.Tensor | None = None, + prompt_cu_seqlens: torch.Tensor | None = None, + negative_prompt_cu_seqlens: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int, None], PipelineCallback | MultiPipelineCallbacks] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for text-to-image generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). + height (`int`, defaults to `1024`): + The height in pixels of the generated image. + width (`int`, defaults to `1024`): + The width in pixels of the generated image. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in classifier-free guidance. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A torch generator to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents. + prompt_embeds_qwen (`torch.Tensor`, *optional*): + Pre-generated Qwen text embeddings. + prompt_embeds_clip (`torch.Tensor`, *optional*): + Pre-generated CLIP text embeddings. + negative_prompt_embeds_qwen (`torch.Tensor`, *optional*): + Pre-generated Qwen negative text embeddings. + negative_prompt_embeds_clip (`torch.Tensor`, *optional*): + Pre-generated CLIP negative text embeddings. + prompt_cu_seqlens (`torch.Tensor`, *optional*): + Pre-generated cumulative sequence lengths for Qwen positive prompt. + negative_prompt_cu_seqlens (`torch.Tensor`, *optional*): + Pre-generated cumulative sequence lengths for Qwen negative prompt. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`KandinskyImagePipelineOutput`]. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function that is called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length for text encoding. + + Examples: + + Returns: + [`~KandinskyImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`KandinskyImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + self.check_inputs( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + prompt_embeds_qwen=prompt_embeds_qwen, + prompt_embeds_clip=prompt_embeds_clip, + negative_prompt_embeds_qwen=negative_prompt_embeds_qwen, + negative_prompt_embeds_clip=negative_prompt_embeds_clip, + prompt_cu_seqlens=prompt_cu_seqlens, + negative_prompt_cu_seqlens=negative_prompt_cu_seqlens, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + if (width, height) not in self.resolutions: + width, height = self.resolutions[ + np.argmin([abs((i[0] / i[1]) - (width / height)) for i in self.resolutions]) + ] + + self._guidance_scale = guidance_scale + self._interrupt = False + + device = self._execution_device + dtype = self.transformer.dtype + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds_qwen.shape[0] + + # 3. Encode input prompt + if prompt_embeds_qwen is None: + prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = self.encode_prompt( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if self.guidance_scale > 1.0: + if negative_prompt is None: + negative_prompt = "" + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * len(prompt) if prompt is not None else [negative_prompt] + elif len(negative_prompt) != len(prompt): + raise ValueError( + f"`negative_prompt` must have same length as `prompt`. Got {len(negative_prompt)} vs {len(prompt)}." + ) + + if negative_prompt_embeds_qwen is None: + negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_prompt_cu_seqlens = ( + self.encode_prompt( + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_visual_dim + latents = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + dtype=dtype, + device=device, + generator=generator, + latents=latents, + ) + + # 6. Prepare rope positions for positional encoding + visual_rope_pos = [ + torch.arange(1, device=device), + torch.arange(height // self.vae_scale_factor_spatial // 2, device=device), + torch.arange(width // self.vae_scale_factor_spatial // 2, device=device), + ] + + text_rope_pos = torch.arange(prompt_cu_seqlens.diff().max().item(), device=device) + + negative_text_rope_pos = ( + torch.arange(negative_prompt_cu_seqlens.diff().max().item(), device=device) + if negative_prompt_cu_seqlens is not None + else None + ) + + # 7. Calculate dynamic scale factor based on resolution + scale_factor = [1.0, 1.0, 1.0] + + # 8. Sparse Params for efficient attention + sparse_params = None + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + timestep = t.unsqueeze(0).repeat(batch_size * num_images_per_prompt) + + # Predict noise residual + pred_velocity = self.transformer( + hidden_states=latents.to(dtype), + encoder_hidden_states=prompt_embeds_qwen.to(dtype), + pooled_projections=prompt_embeds_clip.to(dtype), + timestep=timestep.to(dtype), + visual_rope_pos=visual_rope_pos, + text_rope_pos=text_rope_pos, + scale_factor=scale_factor, + sparse_params=sparse_params, + return_dict=True, + ).sample + + if self.guidance_scale > 1.0 and negative_prompt_embeds_qwen is not None: + uncond_pred_velocity = self.transformer( + hidden_states=latents.to(dtype), + encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype), + pooled_projections=negative_prompt_embeds_clip.to(dtype), + timestep=timestep.to(dtype), + visual_rope_pos=visual_rope_pos, + text_rope_pos=negative_text_rope_pos, + scale_factor=scale_factor, + sparse_params=sparse_params, + return_dict=True, + ).sample + + pred_velocity = uncond_pred_velocity + guidance_scale * (pred_velocity - uncond_pred_velocity) + + latents = self.scheduler.step(pred_velocity[:, :], t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds_qwen = callback_outputs.pop("prompt_embeds_qwen", prompt_embeds_qwen) + prompt_embeds_clip = callback_outputs.pop("prompt_embeds_clip", prompt_embeds_clip) + negative_prompt_embeds_qwen = callback_outputs.pop( + "negative_prompt_embeds_qwen", negative_prompt_embeds_qwen + ) + negative_prompt_embeds_clip = callback_outputs.pop( + "negative_prompt_embeds_clip", negative_prompt_embeds_clip + ) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # 9. Post-processing - extract main latents + latents = latents[:, :, :, :, :num_channels_latents] + + # 10. Decode latents to image + if output_type != "latent": + latents = latents.to(self.vae.dtype) + # Reshape and normalize latents + latents = latents.reshape( + batch_size, + num_images_per_prompt, + 1, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + num_channels_latents, + ) + latents = latents.permute(0, 1, 5, 2, 3, 4) # [batch, num_images, channels, 1, height, width] + latents = latents.reshape( + batch_size * num_images_per_prompt, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + # Normalize and decode through VAE + latents = latents / self.vae.config.scaling_factor + image = self.vae.decode(latents).sample + image = self.image_processor.postprocess(image, output_type=output_type) + else: + image = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return KandinskyImagePipelineOutput(image=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..9b41d2f14424f29c90cea1d064fe144d5513d294 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kandinsky5/pipeline_output.py @@ -0,0 +1,35 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class KandinskyPipelineOutput(BaseOutput): + r""" + Output class for kandinsky video pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor + + +@dataclass +class KandinskyImagePipelineOutput(BaseOutput): + r""" + Output class for kandinsky image pipelines. + + Args: + image (`torch.Tensor`, `np.ndarray`, or list[PIL.Image.Image]): + List of image outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image. It can also be a NumPy array or Torch tensor of shape `(batch_size, channels, height, + width)`. + """ + + image: torch.Tensor diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..671d22e9f433802ec82ebe89fae4cef6c001064b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/__init__.py @@ -0,0 +1,54 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_sentencepiece_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()) and is_sentencepiece_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_and_sentencepiece_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_sentencepiece_objects)) +else: + _import_structure["pipeline_kolors"] = ["KolorsPipeline"] + _import_structure["pipeline_kolors_img2img"] = ["KolorsImg2ImgPipeline"] + _import_structure["text_encoder"] = ["ChatGLMModel"] + _import_structure["tokenizer"] = ["ChatGLMTokenizer"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()) and is_sentencepiece_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_and_sentencepiece_objects import * + + else: + from .pipeline_kolors import KolorsPipeline + from .pipeline_kolors_img2img import KolorsImg2ImgPipeline + from .text_encoder import ChatGLMModel + from .tokenizer import ChatGLMTokenizer + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed0d938d7e1c42b6a6652fb56bbb9b7c3761bad8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/__pycache__/pipeline_kolors.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/__pycache__/pipeline_kolors.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0f96cf06b763822fd2020d299214d309166b70b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/__pycache__/pipeline_kolors.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/__pycache__/pipeline_kolors_img2img.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/__pycache__/pipeline_kolors_img2img.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2012e2b08b8b60ecf1d2bcc9dad7a644eb9fc4e4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/__pycache__/pipeline_kolors_img2img.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/__pycache__/pipeline_output.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/__pycache__/pipeline_output.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..706a894c444b6f249f770c2f1ec0a627d6a84618 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/__pycache__/pipeline_output.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/__pycache__/text_encoder.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/__pycache__/text_encoder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..779c3f7ba60195db22593d663b3bd906b7c7ec69 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/__pycache__/text_encoder.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/__pycache__/tokenizer.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/__pycache__/tokenizer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8051534a8da1a586518dc4221faf7586639f3922 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/__pycache__/tokenizer.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/pipeline_kolors.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/pipeline_kolors.py new file mode 100644 index 0000000000000000000000000000000000000000..1e11faf8b9b6b2e3194a315de56b9047199f0649 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/pipeline_kolors.py @@ -0,0 +1,1063 @@ +# Copyright 2025 Stability AI, Kwai-Kolors Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import Any, Callable + +import torch +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import KolorsPipelineOutput +from .text_encoder import ChatGLMModel +from .tokenizer import ChatGLMTokenizer + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import KolorsPipeline + + >>> pipe = KolorsPipeline.from_pretrained( + ... "Kwai-Kolors/Kolors-diffusers", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = ( + ... "A photo of a ladybug, macro, zoom, high quality, film, holding a wooden sign with the text 'KOLORS'" + ... ) + >>> image = pipe(prompt).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionLoraLoaderMixin, IPAdapterMixin): + r""" + Pipeline for text-to-image generation using Kolors. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`ChatGLMModel`]): + Frozen text-encoder. Kolors uses [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b). + tokenizer (`ChatGLMTokenizer`): + Tokenizer of class + [ChatGLMTokenizer](https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"False"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `Kwai-Kolors/Kolors-diffusers`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = [ + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: ChatGLMModel, + tokenizer: ChatGLMTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = False, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) + + def encode_prompt( + self, + prompt, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt=None, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 256, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + """ + # from IPython import embed; embed(); exit() + device = device or self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer] + text_encoders = [self.text_encoder] + + if prompt_embeds is None: + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ).to(device) + output = text_encoder( + input_ids=text_inputs["input_ids"], + attention_mask=text_inputs["attention_mask"], + position_ids=text_inputs["position_ids"], + output_hidden_states=True, + ) + + # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] + # clone to have a contiguous tensor + prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone() + # [max_sequence_length, batch, hidden_size] -> [batch, hidden_size] + pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = prompt_embeds_list[0] + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ).to(device) + output = text_encoder( + input_ids=uncond_input["input_ids"], + attention_mask=uncond_input["attention_mask"], + position_ids=uncond_input["position_ids"], + output_hidden_states=True, + ) + + # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] + # clone to have a contiguous tensor + negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone() + # [max_sequence_length, batch, hidden_size] -> [batch, hidden_size] + negative_pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = negative_prompt_embeds_list[0] + + bs_embed = pooled_prompt_embeds.shape[0] + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + num_inference_steps, + height, + width, + negative_prompt=None, + prompt_embeds=None, + pooled_prompt_embeds=None, + negative_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + if max_sequence_length is not None and max_sequence_length > 256: + raise ValueError(f"`max_sequence_length` cannot be greater than 256 but is {max_sequence_length}") + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + denoising_end: float | None = None, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + original_size: tuple[int, int] | None = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] | None = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 256, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints + that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints + that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.kolors.KolorsPipelineOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.kolors.KolorsPipelineOutput`] or `tuple`: [`~pipelines.kolors.KolorsPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + num_inference_steps, + height, + width, + negative_prompt, + prompt_embeds, + pooled_prompt_embeds, + negative_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + ) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9. Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return KolorsPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/pipeline_kolors_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/pipeline_kolors_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..d9b5192672168d2105679de7f37d110936f02acb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/pipeline_kolors_img2img.py @@ -0,0 +1,1243 @@ +# Copyright 2025 Stability AI, Kwai-Kolors Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import Any, Callable + +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import empty_device_cache, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import KolorsPipelineOutput +from .text_encoder import ChatGLMModel +from .tokenizer import ChatGLMTokenizer + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import KolorsImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> pipe = KolorsImg2ImgPipeline.from_pretrained( + ... "Kwai-Kolors/Kolors-diffusers", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + >>> url = ( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/kolors/bunny_source.png" + ... ) + + + >>> init_image = load_image(url) + >>> prompt = "high quality image of a capybara wearing sunglasses. In the background of the image there are trees, poles, grass and other objects. At the bottom of the object there is the road., 8k, highly detailed." + >>> image = pipe(prompt, image=init_image).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin): + r""" + Pipeline for text-to-image generation using Kolors. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`ChatGLMModel`]): + Frozen text-encoder. Kolors uses [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b). + tokenizer (`ChatGLMTokenizer`): + Tokenizer of class + [ChatGLMTokenizer](https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"False"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `Kwai-Kolors/Kolors-diffusers`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder-unet->vae" + _optional_components = [ + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: ChatGLMModel, + tokenizer: ChatGLMTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = False, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) + + # Copied from diffusers.pipelines.kolors.pipeline_kolors.KolorsPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt=None, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 256, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + """ + # from IPython import embed; embed(); exit() + device = device or self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer] + text_encoders = [self.text_encoder] + + if prompt_embeds is None: + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ).to(device) + output = text_encoder( + input_ids=text_inputs["input_ids"], + attention_mask=text_inputs["attention_mask"], + position_ids=text_inputs["position_ids"], + output_hidden_states=True, + ) + + # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] + # clone to have a contiguous tensor + prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone() + # [max_sequence_length, batch, hidden_size] -> [batch, hidden_size] + pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = prompt_embeds_list[0] + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ).to(device) + output = text_encoder( + input_ids=uncond_input["input_ids"], + attention_mask=uncond_input["attention_mask"], + position_ids=uncond_input["position_ids"], + output_hidden_states=True, + ) + + # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] + # clone to have a contiguous tensor + negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone() + # [max_sequence_length, batch, hidden_size] -> [batch, hidden_size] + negative_pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = negative_prompt_embeds_list[0] + + bs_embed = pooled_prompt_embeds.shape[0] + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + strength, + num_inference_steps, + height, + width, + negative_prompt=None, + prompt_embeds=None, + pooled_prompt_embeds=None, + negative_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + if max_sequence_length is not None and max_sequence_length > 256: + raise ValueError(f"`max_sequence_length` cannot be greater than 256 but is {max_sequence_length}") + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(self.scheduler.timesteps) - num_inference_steps + timesteps = self.scheduler.timesteps[t_start:] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents + def prepare_latents( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + + # Offload text encoder if `enable_model_cpu_offload` was enabled + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.text_encoder_2.to("cpu") + empty_device_cache() + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_start(self): + return self._denoising_start + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + strength: float = 0.3, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + denoising_start: float | None = None, + denoising_end: float | None = None, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + original_size: tuple[int, int] | None = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] | None = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 256, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.Tensor` or `PIL.Image.Image` or `np.ndarray` or `list[torch.Tensor]` or `list[PIL.Image.Image]` or `list[np.ndarray]`): + The image(s) to modify with the pipeline. + strength (`float`, *optional*, defaults to 0.3): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of + `denoising_start` being declared as an integer, the value of `strength` will be ignored. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints + that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints + that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_start (`float`, *optional*): + When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be + bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and + it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, + strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline + is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image + Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality). + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.kolors.KolorsPipelineOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.kolors.KolorsPipelineOutput`] or `tuple`: [`~pipelines.kolors.KolorsPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + strength, + num_inference_steps, + height, + width, + negative_prompt, + prompt_embeds, + pooled_prompt_embeds, + negative_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._denoising_start = denoising_start + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. Prepare timesteps + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None, + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + add_noise = True if self.denoising_start is None else False + + # 6. Prepare latent variables + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + add_noise, + ) + + # 7. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + height, width = latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 8. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 9. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 9.1 Apply denoising_end + if ( + self.denoising_end is not None + and self.denoising_start is not None + and denoising_value_valid(self.denoising_end) + and denoising_value_valid(self.denoising_start) + and self.denoising_start >= self.denoising_end + ): + raise ValueError( + f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " + + f" {self.denoising_end} when using type float." + ) + elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return KolorsPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..0418191cc6e50059d7b564b641ef8ca35a61deef --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class KolorsPipelineOutput(BaseOutput): + """ + Output class for Kolors pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image] | np.ndarray diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/text_encoder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..434f4fed6fbbb5f359e8ed70932a02d183bb83a5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/text_encoder.py @@ -0,0 +1,864 @@ +# Copyright 2025 ChatGLM3-6B Model Team, Kwai-Kolors Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import LayerNorm +from torch.nn.utils import skip_init +from transformers import PretrainedConfig, PreTrainedModel +from transformers.modeling_outputs import BaseModelOutputWithPast + +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class ChatGLMConfig(PretrainedConfig): + model_type = "chatglm" + + def __init__( + self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + classifier_dropout=None, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs, + ): + self.num_layers = num_layers + self.vocab_size = padded_vocab_size + self.padded_vocab_size = padded_vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.kv_channels = kv_channels + self.num_attention_heads = num_attention_heads + self.seq_length = seq_length + self.hidden_dropout = hidden_dropout + self.classifier_dropout = classifier_dropout + self.attention_dropout = attention_dropout + self.layernorm_epsilon = layernorm_epsilon + self.rmsnorm = rmsnorm + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.post_layer_norm = post_layer_norm + self.add_bias_linear = add_bias_linear + self.add_qkv_bias = add_qkv_bias + self.bias_dropout_fusion = bias_dropout_fusion + self.multi_query_attention = multi_query_attention + self.multi_query_group_num = multi_query_group_num + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.fp32_residual_connection = fp32_residual_connection + self.quantization_bit = quantization_bit + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + super().__init__(**kwargs) + + +class RMSNorm(torch.nn.Module): + def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) + self.eps = eps + + def forward(self, hidden_states: torch.Tensor): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + return (self.weight * hidden_states).to(input_dtype) + + +class CoreAttention(torch.nn.Module): + def __init__(self, config: ChatGLMConfig, layer_number): + super(CoreAttention, self).__init__() + + self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + self.hidden_size_per_partition = projection_size + self.hidden_size_per_attention_head = projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + self.coeff = coeff + + self.attention_dropout = torch.nn.Dropout(config.attention_dropout) + + def forward(self, query_layer, key_layer, value_layer, attention_mask): + pytorch_major_version = int(torch.__version__.split(".")[0]) + if pytorch_major_version >= 2: + query_layer, key_layer, value_layer = [ + k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer] + ] + if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, is_causal=True + ) + else: + if attention_mask is not None: + attention_mask = ~attention_mask + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, attention_mask + ) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(*new_context_layer_shape) + else: + # Raw attention scores + + # [b, np, sq, sk] + output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = torch.empty( + output_size[0] * output_size[1], + output_size[2], + output_size[3], + dtype=query_layer.dtype, + device=query_layer.device, + ) + + # Raw attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + if self.attention_softmax_in_fp32: + attention_scores = attention_scores.float() + if self.coeff is not None: + attention_scores = attention_scores * self.coeff + if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: + attention_mask = torch.ones( + output_size[0], 1, output_size[2], output_size[3], device=attention_scores.device, dtype=torch.bool + ) + attention_mask.tril_() + attention_mask = ~attention_mask + if attention_mask is not None: + attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = attention_probs.type_as(value_layer) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.attention_dropout(attention_probs) + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) + # change view [sk, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> list[torch.Tensor]: + """Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +@torch.jit.script +def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: + # x: [sq, b, np, hn] + sq, _b, np, _hn = x.size(0), x.size(1), x.size(2), x.size(3) + rot_dim = rope_cache.shape[-2] * 2 + x, x_pass = x[..., :rot_dim], x[..., rot_dim:] + # truncate to support variable sizes + rope_cache = rope_cache[:sq] + xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) + rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + x_out2 = x_out2.flatten(3) + return torch.cat((x_out2, x_pass), dim=-1) + + +class SelfAttention(torch.nn.Module): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] and returns output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(SelfAttention, self).__init__() + self.layer_number = max(1, layer_number) + + self.projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + self.multi_query_attention = config.multi_query_attention + self.qkv_hidden_size = 3 * self.projection_size + if self.multi_query_attention: + self.num_multi_query_groups_per_partition = config.multi_query_group_num + self.qkv_hidden_size = ( + self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num + ) + self.query_key_value = nn.Linear( + config.hidden_size, + self.qkv_hidden_size, + bias=config.add_bias_linear or config.add_qkv_bias, + device=device, + ) + + self.core_attention = CoreAttention(config, self.layer_number) + + # Output. + self.dense = nn.Linear( + self.projection_size, + config.hidden_size, + bias=config.add_bias_linear, + device=device, + ) + + def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): + if self.multi_query_attention: + num_attention_heads = self.num_multi_query_groups_per_partition + else: + num_attention_heads = self.num_attention_heads_per_partition + return torch.empty( + inference_max_sequence_len, + batch_size, + num_attention_heads, + self.hidden_size_per_attention_head, + dtype=dtype, + device=device, + ) + + def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + # ===================== + # Query, Key, and Value + # ===================== + + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer = self.query_key_value(hidden_states) + + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view( + query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + key_layer = key_layer.view( + key_layer.size()[:-1] + + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) + ) + value_layer = value_layer.view( + value_layer.size()[:-1] + + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) + ) + else: + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) + key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) + + # adjust key and value for inference + if kv_cache is not None: + cache_k, cache_v = kv_cache + key_layer = torch.cat((cache_k, key_layer), dim=0) + value_layer = torch.cat((cache_v, value_layer), dim=0) + if use_cache: + kv_cache = (key_layer, value_layer) + else: + kv_cache = None + + if self.multi_query_attention: + key_layer = key_layer.unsqueeze(-2) + key_layer = key_layer.expand( + -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 + ) + key_layer = key_layer.contiguous().view( + key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + value_layer = value_layer.unsqueeze(-2) + value_layer = value_layer.expand( + -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 + ) + value_layer = value_layer.contiguous().view( + value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + + # ================================== + # core attention computation + # ================================== + + context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + + # ================= + # Output. [sq, b, h] + # ================= + + output = self.dense(context_layer) + + return output, kv_cache + + +class MLP(torch.nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h hidden dimension, perform nonlinear transformation, + and project the state back into h hidden dimension. + """ + + def __init__(self, config: ChatGLMConfig, device=None): + super(MLP, self).__init__() + + self.add_bias = config.add_bias_linear + + # Project to 4h. If using swiglu double the output width, see https://huggingface.co/papers/2002.05202 + self.dense_h_to_4h = nn.Linear( + config.hidden_size, + config.ffn_hidden_size * 2, + bias=self.add_bias, + device=device, + ) + + def swiglu(x): + x = torch.chunk(x, 2, dim=-1) + return F.silu(x[0]) * x[1] + + self.activation_func = swiglu + + # Project back to h. + self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output = self.dense_4h_to_h(intermediate_parallel) + return output + + +class GLMBlock(torch.nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(GLMBlock, self).__init__() + self.layer_number = layer_number + + self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm + + self.fp32_residual_connection = config.fp32_residual_connection + + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Layernorm on the input data. + self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device) + + # Self attention. + self.self_attention = SelfAttention(config, layer_number, device=device) + self.hidden_dropout = config.hidden_dropout + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device) + + # MLP + self.mlp = MLP(config, device=device) + + def forward( + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=None, + use_cache=True, + ): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, kv_cache = self.self_attention( + layernorm_output, attention_mask, rotary_pos_emb, kv_cache=kv_cache, use_cache=use_cache + ) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) + layernorm_input = residual + layernorm_input + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) + output = residual + output + + return output, kv_cache + + +class GLMTransformer(torch.nn.Module): + """Transformer class.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(GLMTransformer, self).__init__() + + self.fp32_residual_connection = config.fp32_residual_connection + self.post_layer_norm = config.post_layer_norm + + # Number of layers. + self.num_layers = config.num_layers + + # Transformer layers. + def build_layer(layer_number): + return GLMBlock(config, layer_number, device=device) + + self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) + + if self.post_layer_norm: + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Final layer norm before output. + self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device) + + self.gradient_checkpointing = False + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def forward( + self, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches=None, + use_cache: bool | None = True, + output_hidden_states: bool | None = False, + ): + if not kv_caches: + kv_caches = [None for _ in range(self.num_layers)] + presents = () if use_cache else None + if torch.is_grad_enabled() and self.gradient_checkpointing: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_self_attentions = None + all_hidden_states = () if output_hidden_states else None + for index in range(self.num_layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer = self._get_layer(index) + if torch.is_grad_enabled() and self.gradient_checkpointing: + layer_ret = self._gradient_checkpointing_func( + layer, hidden_states, attention_mask, rotary_pos_emb, kv_caches[index], use_cache + ) + else: + layer_ret = layer( + hidden_states, attention_mask, rotary_pos_emb, kv_cache=kv_caches[index], use_cache=use_cache + ) + hidden_states, kv_cache = layer_ret + if use_cache: + presents = presents + (kv_cache,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Final layer norm. + if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states, presents, all_hidden_states, all_self_attentions + + +class ChatGLMPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + is_parallelizable = False + supports_gradient_checkpointing = True + config_class = ChatGLMConfig + base_model_prefix = "transformer" + _no_split_modules = ["GLMBlock"] + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + return + + def get_masks(self, input_ids, past_key_values, padding_mask=None): + batch_size, seq_length = input_ids.shape + full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) + full_attention_mask.tril_() + past_length = 0 + if past_key_values: + past_length = past_key_values[0][0].shape[0] + if past_length: + full_attention_mask = torch.cat( + (torch.ones(batch_size, seq_length, past_length, device=input_ids.device), full_attention_mask), dim=-1 + ) + if padding_mask is not None: + full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) + if not past_length and padding_mask is not None: + full_attention_mask -= padding_mask.unsqueeze(-1) - 1 + full_attention_mask = (full_attention_mask < 0.5).bool() + full_attention_mask.unsqueeze_(1) + return full_attention_mask + + def get_position_ids(self, input_ids, device): + batch_size, seq_length = input_ids.shape + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + return position_ids + + +def default_init(cls, *args, **kwargs): + return cls(*args, **kwargs) + + +class Embedding(torch.nn.Module): + """Language model embeddings.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(Embedding, self).__init__() + + self.hidden_size = config.hidden_size + # Word embeddings (parallel). + self.word_embeddings = nn.Embedding(config.padded_vocab_size, self.hidden_size, device=device) + self.fp32_residual_connection = config.fp32_residual_connection + + def forward(self, input_ids): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + embeddings = words_embeddings + # Data format change to avoid explicit transposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + # If the input flag for fp32 residual connection is set, convert for float. + if self.fp32_residual_connection: + embeddings = embeddings.float() + return embeddings + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, original_impl=False, device=None, dtype=None): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) + self.register_buffer("inv_freq", inv_freq) + self.dim = dim + self.original_impl = original_impl + + def forward_impl(self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000): + """Enhanced Transformer with Rotary Position Embedding. + + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. + """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, dtype=torch.float, device=device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).float() + + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + + # this is to mimic the behaviour of complex32, else we will get different results + if dtype in (torch.float16, torch.bfloat16, torch.int8): + cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() + return cache + + def forward(self, max_seq_len, offset=0): + return self.forward_impl(max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device) + + +class PrefixEncoder(torch.nn.Module): + """ + The torch.nn model to encode the prefix Input shape: (batch-size, prefix-length) Output shape: (batch-size, + prefix-length, 2*layers*hidden) + """ + + def __init__(self, config: ChatGLMConfig): + super().__init__() + self.prefix_projection = config.prefix_projection + if self.prefix_projection: + # Use a two-layer MLP to encode the prefix + kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2 + self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size) + self.trans = torch.nn.Sequential( + torch.nn.Linear(kv_size, config.hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, kv_size), + ) + else: + self.embedding = torch.nn.Embedding( + config.pre_seq_len, config.num_layers * config.kv_channels * config.multi_query_group_num * 2 + ) + + def forward(self, prefix: torch.Tensor): + if self.prefix_projection: + prefix_tokens = self.embedding(prefix) + past_key_values = self.trans(prefix_tokens) + else: + past_key_values = self.embedding(prefix) + return past_key_values + + +class ChatGLMModel(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + init_kwargs = {} + if device is not None: + init_kwargs["device"] = device + self.embedding = init_method(Embedding, config, **init_kwargs) + self.num_layers = config.num_layers + self.multi_query_group_num = config.multi_query_group_num + self.kv_channels = config.kv_channels + + # Rotary positional embeddings + self.seq_length = config.seq_length + rotary_dim = ( + config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels + ) + + self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device) + self.encoder = init_method(GLMTransformer, config, **init_kwargs) + self.output_layer = init_method( + nn.Linear, + config.hidden_size, + config.padded_vocab_size, + bias=False, + **init_kwargs, + ) + self.pre_seq_len = config.pre_seq_len + self.prefix_projection = config.prefix_projection + if self.pre_seq_len is not None: + for param in self.parameters(): + param.requires_grad = False + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + self.dropout = torch.nn.Dropout(0.1) + + if hasattr(self, "post_init"): + self.post_init() + + def get_input_embeddings(self): + return self.embedding.word_embeddings + + def get_prompt(self, batch_size, device, dtype=torch.half): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) + past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) + past_key_values = past_key_values.view( + batch_size, self.pre_seq_len, self.num_layers * 2, self.multi_query_group_num, self.kv_channels + ) + # seq_len, b, nh, hidden_size + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids, + position_ids: torch.Tensor | None = None, + attention_mask: torch.BoolTensor | None = None, + full_attention_mask: torch.BoolTensor | None = None, + past_key_values: tuple[tuple[torch.Tensor, torch.Tensor], ...] | None = None, + inputs_embeds: torch.Tensor | None = None, + use_cache: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else getattr(self.config, "use_cache", None) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt( + batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype + ) + if attention_mask is not None: + attention_mask = torch.cat( + [attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1 + ) + + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, + full_attention_mask, + rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + ) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/tokenizer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..c0a1831bdeeef5369e7299fc76f6426c5171b06a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/kolors/tokenizer.py @@ -0,0 +1,337 @@ +# Copyright 2025 ChatGLM3-6B Model Team, Kwai-Kolors Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import re + +from sentencepiece import SentencePieceProcessor +from transformers import PreTrainedTokenizer +from transformers.tokenization_utils_base import BatchEncoding, EncodedInput +from transformers.utils import PaddingStrategy + + +class SPTokenizer: + def __init__(self, model_path: str): + # reload tokenizer + assert os.path.isfile(model_path), model_path + self.sp_model = SentencePieceProcessor(model_file=model_path) + + # BOS / EOS token IDs + self.n_words: int = self.sp_model.vocab_size() + self.bos_id: int = self.sp_model.bos_id() + self.eos_id: int = self.sp_model.eos_id() + self.pad_id: int = self.sp_model.unk_id() + assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() + + role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"] + special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens + self.special_tokens = {} + self.index_special_tokens = {} + for token in special_tokens: + self.special_tokens[token] = self.n_words + self.index_special_tokens[self.n_words] = token + self.n_words += 1 + self.role_special_token_expression = "|".join([re.escape(token) for token in role_special_tokens]) + + def tokenize(self, s: str, encode_special_tokens=False): + if encode_special_tokens: + last_index = 0 + t = [] + for match in re.finditer(self.role_special_token_expression, s): + if last_index < match.start(): + t.extend(self.sp_model.EncodeAsPieces(s[last_index : match.start()])) + t.append(s[match.start() : match.end()]) + last_index = match.end() + if last_index < len(s): + t.extend(self.sp_model.EncodeAsPieces(s[last_index:])) + return t + else: + return self.sp_model.EncodeAsPieces(s) + + def encode(self, s: str, bos: bool = False, eos: bool = False) -> list[int]: + assert isinstance(s, str) + t = self.sp_model.encode(s) + if bos: + t = [self.bos_id] + t + if eos: + t = t + [self.eos_id] + return t + + def decode(self, t: list[int]) -> str: + text, buffer = "", [] + for token in t: + if token in self.index_special_tokens: + if buffer: + text += self.sp_model.decode(buffer) + buffer = [] + text += self.index_special_tokens[token] + else: + buffer.append(token) + if buffer: + text += self.sp_model.decode(buffer) + return text + + def decode_tokens(self, tokens: list[str]) -> str: + text = self.sp_model.DecodePieces(tokens) + return text + + def convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + if token in self.special_tokens: + return self.special_tokens[token] + return self.sp_model.PieceToId(token) + + def convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + if index in self.index_special_tokens: + return self.index_special_tokens[index] + if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0: + return "" + return self.sp_model.IdToPiece(index) + + +class ChatGLMTokenizer(PreTrainedTokenizer): + vocab_files_names = {"vocab_file": "tokenizer.model"} + + model_input_names = ["input_ids", "attention_mask", "position_ids"] + + def __init__( + self, + vocab_file, + padding_side="left", + clean_up_tokenization_spaces=False, + encode_special_tokens=False, + **kwargs, + ): + self.name = "GLMTokenizer" + + self.vocab_file = vocab_file + self.tokenizer = SPTokenizer(vocab_file) + self.special_tokens = { + "": self.tokenizer.bos_id, + "": self.tokenizer.eos_id, + "": self.tokenizer.pad_id, + } + self.encode_special_tokens = encode_special_tokens + super().__init__( + padding_side=padding_side, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + encode_special_tokens=encode_special_tokens, + **kwargs, + ) + + def get_command(self, token): + if token in self.special_tokens: + return self.special_tokens[token] + assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}" + return self.tokenizer.special_tokens[token] + + @property + def unk_token(self) -> str: + return "" + + @unk_token.setter + def unk_token(self, value: str): + self._unk_token = value + + @property + def pad_token(self) -> str: + return "" + + @pad_token.setter + def pad_token(self, value: str): + self._pad_token = value + + @property + def pad_token_id(self): + return self.get_command("") + + @property + def eos_token(self) -> str: + return "" + + @eos_token.setter + def eos_token(self, value: str): + self._eos_token = value + + @property + def eos_token_id(self): + return self.get_command("") + + @property + def vocab_size(self): + return self.tokenizer.n_words + + def get_vocab(self): + """Returns vocab as a dict""" + vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text, **kwargs): + return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.tokenizer.convert_token_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.tokenizer.convert_id_to_token(index) + + def convert_tokens_to_string(self, tokens: list[str]) -> str: + return self.tokenizer.decode_tokens(tokens) + + def save_vocabulary(self, save_directory, filename_prefix=None): + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + filename_prefix (`str`, *optional*): + An optional prefix to add to the named of the saved files. + + Returns: + `tuple(str)`: Paths to the files saved. + """ + if os.path.isdir(save_directory): + vocab_file = os.path.join(save_directory, self.vocab_files_names["vocab_file"]) + else: + vocab_file = save_directory + + with open(self.vocab_file, "rb") as fin: + proto_str = fin.read() + + with open(vocab_file, "wb") as writer: + writer.write(proto_str) + + return (vocab_file,) + + def get_prefix_tokens(self): + prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")] + return prefix_tokens + + def build_single_message(self, role, metadata, message): + assert role in ["system", "user", "assistant", "observation"], role + role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n") + message_tokens = self.tokenizer.encode(message) + tokens = role_tokens + message_tokens + return tokens + + def build_chat_input(self, query, history=None, role="user"): + if history is None: + history = [] + input_ids = [] + for item in history: + content = item["content"] + if item["role"] == "system" and "tools" in item: + content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False) + input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content)) + input_ids.extend(self.build_single_message(role, "", query)) + input_ids.extend([self.get_command("<|assistant|>")]) + return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True) + + def build_inputs_with_special_tokens( + self, token_ids_0: list[int], token_ids_1: list[int] | None = None + ) -> list[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`list[int]`): + list of IDs to which the special tokens will be added. + token_ids_1 (`list[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `list[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + prefix_tokens = self.get_prefix_tokens() + token_ids_0 = prefix_tokens + token_ids_0 + if token_ids_1 is not None: + token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("")] + return token_ids_0 + + def _pad( + self, + encoded_inputs: dict[str, EncodedInput] | BatchEncoding, + max_length: int | None = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: int | None = None, + return_attention_mask: bool | None = None, + padding_side: bool | None = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`list[int]`) or batch of tokenized inputs (`list[list[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + padding_side (`str`, *optional*): + The side on which the model should have padding applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + assert self.padding_side == "left" + + required_input = encoded_inputs[self.model_input_names[0]] + seq_length = len(required_input) + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length + + # Initialize attention mask if not present. + if "attention_mask" not in encoded_inputs: + encoded_inputs["attention_mask"] = [1] * seq_length + + if "position_ids" not in encoded_inputs: + encoded_inputs["position_ids"] = list(range(seq_length)) + + if needs_to_be_padded: + difference = max_length - len(required_input) + + if "attention_mask" in encoded_inputs: + encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"] + if "position_ids" in encoded_inputs: + encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"] + encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input + + return encoded_inputs diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/latent_consistency_models/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/latent_consistency_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f79d3c4773f393ed689a949041d36ad77e20968 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/latent_consistency_models/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_latent_consistency_img2img"] = ["LatentConsistencyModelImg2ImgPipeline"] + _import_structure["pipeline_latent_consistency_text2img"] = ["LatentConsistencyModelPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_latent_consistency_img2img import LatentConsistencyModelImg2ImgPipeline + from .pipeline_latent_consistency_text2img import LatentConsistencyModelPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..6861438e4c63e6412d32575af405b01642ca5ddc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py @@ -0,0 +1,991 @@ +# Copyright 2025 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import inspect +from typing import Any, Callable + +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import LCMScheduler +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import AutoPipelineForImage2Image + >>> import torch + >>> import PIL + + >>> pipe = AutoPipelineForImage2Image.from_pretrained("SimianLuo/LCM_Dreamshaper_v7") + >>> # To save GPU memory, torch.float16 can be used, but it may compromise image quality. + >>> pipe.to(torch_device="cuda", torch_dtype=torch.float32) + + >>> prompt = "High altitude snowy mountains" + >>> image = PIL.Image.open("./snowy_mountains.png") + + >>> # Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps. + >>> num_inference_steps = 4 + >>> images = pipe( + ... prompt=prompt, image=image, num_inference_steps=num_inference_steps, guidance_scale=8.0 + ... ).images + + >>> images[0].save("image.png") + ``` + +""" + + +class LatentConsistencyModelImg2ImgPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + Pipeline for image-to-image generation using a latent consistency model. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Currently only + supports [`LCMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + requires_safety_checker (`bool`, *optional*, defaults to `True`): + Whether the pipeline requires a safety checker component. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "denoised", "prompt_embeds", "w_embedding"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: LCMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection | None = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt: str | list[str], + strength: float, + callback_steps: int, + prompt_embeds: torch.Tensor | None = None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def clip_skip(self): + return self._clip_skip + + @property + def do_classifier_free_guidance(self): + return False + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + num_inference_steps: int = 4, + strength: float = 0.8, + original_inference_steps: int = None, + timesteps: list[int] = None, + guidance_scale: float = 8.5, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + original_inference_steps (`int`, *optional*): + The original number of inference steps use to generate a linearly-spaced timestep schedule, from which + we will draw `num_inference_steps` evenly spaced timesteps from as our final timestep schedule, + following the Skipping-Step method in the paper (see Section 4.3). If not set this will default to the + scheduler's `original_inference_steps` attribute. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps on the original LCM training/distillation timestep schedule are used. Must be in descending + order. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + Note that the original latent consistency models paper uses a different CFG formulation where the + guidance scales are decreased by 1 (so in the paper formulation CFG is enabled when `guidance_scale > + 0`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + strength, + callback_steps, + prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + # NOTE: when a LCM is distilled from an LDM via latent consistency distillation (Algorithm 1) with guided + # distillation, the forward pass of the LCM learns to approximate sampling from the LDM using CFG with the + # unconditional prompt "" (the empty string). Due to this, LCMs currently do not support negative prompts. + prompt_embeds, _ = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=None, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Encode image + image = self.image_processor.preprocess(image) + + # 5. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + timestep_device, + timesteps, + original_inference_steps=original_inference_steps, + strength=strength, + ) + + # 6. Prepare latent variables + original_inference_steps = ( + original_inference_steps + if original_inference_steps is not None + else self.scheduler.config.original_inference_steps + ) + latent_timestep = timesteps[:1] + if latents is None: + latents = self.prepare_latents( + image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator + ) + bs = batch_size * num_images_per_prompt + + # 6. Get Guidance Scale Embedding + # NOTE: We use the Imagen CFG formulation that StableDiffusionPipeline uses rather than the original LCM paper + # CFG formulation, so we need to subtract 1 from the input guidance_scale. + # LCM CFG formulation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond), (cfg_scale > 0.0 using CFG) + w = torch.tensor(self.guidance_scale - 1).repeat(bs) + w_embedding = self.get_guidance_scale_embedding(w, embedding_dim=self.unet.config.time_cond_proj_dim).to( + device=device, dtype=latents.dtype + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 8. LCM Multistep Sampling Loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latents = latents.to(prompt_embeds.dtype) + + # model prediction (v-prediction, eps, x) + model_pred = self.unet( + latents, + t, + timestep_cond=w_embedding, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents, denoised = self.scheduler.step(model_pred, t, latents, **extra_step_kwargs, return_dict=False) + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + w_embedding = callback_outputs.pop("w_embedding", w_embedding) + denoised = callback_outputs.pop("denoised", denoised) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + denoised = denoised.to(prompt_embeds.dtype) + if not output_type == "latent": + image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = denoised + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py new file mode 100644 index 0000000000000000000000000000000000000000..60f59ec7f9d3dd6416cacccb2ef11b665dbb8faf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py @@ -0,0 +1,925 @@ +# Copyright 2025 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import inspect +from typing import Any, Callable + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import LCMScheduler +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import DiffusionPipeline + >>> import torch + + >>> pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7") + >>> # To save GPU memory, torch.float16 can be used, but it may compromise image quality. + >>> pipe.to(torch_device="cuda", torch_dtype=torch.float32) + + >>> prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k" + + >>> # Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps. + >>> num_inference_steps = 4 + >>> images = pipe(prompt=prompt, num_inference_steps=num_inference_steps, guidance_scale=8.0).images + >>> images[0].save("image.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class LatentConsistencyModelPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using a latent consistency model. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Currently only + supports [`LCMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + requires_safety_checker (`bool`, *optional*, defaults to `True`): + Whether the pipeline requires a safety checker component. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "denoised", "prompt_embeds", "w_embedding"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: LCMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection | None = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Currently StableDiffusionPipeline.check_inputs with negative prompt stuff removed + def check_inputs( + self, + prompt: str | list[str], + height: int, + width: int, + callback_steps: int, + prompt_embeds: torch.Tensor | None = None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def clip_skip(self): + return self._clip_skip + + @property + def do_classifier_free_guidance(self): + return False + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 4, + original_inference_steps: int = None, + timesteps: list[int] = None, + guidance_scale: float = 8.5, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + original_inference_steps (`int`, *optional*): + The original number of inference steps use to generate a linearly-spaced timestep schedule, from which + we will draw `num_inference_steps` evenly spaced timesteps from as our final timestep schedule, + following the Skipping-Step method in the paper (see Section 4.3). If not set this will default to the + scheduler's `original_inference_steps` attribute. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps on the original LCM training/distillation timestep schedule are used. Must be in descending + order. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + Note that the original latent consistency models paper uses a different CFG formulation where the + guidance scales are decreased by 1 (so in the paper formulation CFG is enabled when `guidance_scale > + 0`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + # NOTE: when a LCM is distilled from an LDM via latent consistency distillation (Algorithm 1) with guided + # distillation, the forward pass of the LCM learns to approximate sampling from the LDM using CFG with the + # unconditional prompt "" (the empty string). Due to this, LCMs currently do not support negative prompts. + prompt_embeds, _ = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=None, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + timestep_device, + timesteps, + original_inference_steps=original_inference_steps, + ) + + # 5. Prepare latent variable + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + bs = batch_size * num_images_per_prompt + + # 6. Get Guidance Scale Embedding + # NOTE: We use the Imagen CFG formulation that StableDiffusionPipeline uses rather than the original LCM paper + # CFG formulation, so we need to subtract 1 from the input guidance_scale. + # LCM CFG formulation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond), (cfg_scale > 0.0 using CFG) + w = torch.tensor(self.guidance_scale - 1).repeat(bs) + w_embedding = self.get_guidance_scale_embedding(w, embedding_dim=self.unet.config.time_cond_proj_dim).to( + device=device, dtype=latents.dtype + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 8. LCM MultiStep Sampling Loop: + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latents = latents.to(prompt_embeds.dtype) + + # model prediction (v-prediction, eps, x) + model_pred = self.unet( + latents, + t, + timestep_cond=w_embedding, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents, denoised = self.scheduler.step(model_pred, t, latents, **extra_step_kwargs, return_dict=False) + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + w_embedding = callback_outputs.pop("w_embedding", w_embedding) + denoised = callback_outputs.pop("denoised", denoised) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + denoised = denoised.to(prompt_embeds.dtype) + if not output_type == "latent": + image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = denoised + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/latent_diffusion/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/latent_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..561f96fc71dc7b4404e09571e0b7eaa4ee02fde8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/latent_diffusion/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_latent_diffusion"] = ["LDMBertModel", "LDMTextToImagePipeline"] + _import_structure["pipeline_latent_diffusion_superresolution"] = ["LDMSuperResolutionPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_latent_diffusion import LDMBertModel, LDMTextToImagePipeline + from .pipeline_latent_diffusion_superresolution import LDMSuperResolutionPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..ec43988f93892d0e719874f56dbae914354205ce --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -0,0 +1,745 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import torch +import torch.nn as nn +from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutput +from transformers.utils import logging + +from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import is_torch_xla_available +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +class LDMTextToImagePipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using latent diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [`~transformers.BERT`]. + tokenizer ([`~transformers.BertTokenizer`]): + A `BertTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "bert->unet->vqvae" + + def __init__( + self, + vqvae: VQModel | AutoencoderKL, + bert: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + unet: UNet2DModel | UNet2DConditionModel, + scheduler: DDIMScheduler | PNDMScheduler | LMSDiscreteScheduler, + ): + super().__init__() + self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str], + height: int | None = None, + width: int | None = None, + num_inference_steps: int | None = 50, + guidance_scale: float | None = 1.0, + eta: float | None = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + **kwargs, + ) -> tuple | ImagePipelineOutput: + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 1.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. + + Example: + + ```py + >>> from diffusers import DiffusionPipeline + + >>> # load model and scheduler + >>> ldm = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") + + >>> # run pipeline in inference (sample random noise and denoise) + >>> prompt = "A painting of a squirrel eating a burger" + >>> images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images + + >>> # save images + >>> for idx, image in enumerate(images): + ... image.save(f"squirrel-{idx}.png") + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get unconditional embeddings for classifier free guidance + if guidance_scale != 1.0: + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=77, truncation=True, return_tensors="pt" + ) + negative_prompt_embeds = self.bert(uncond_input.input_ids.to(self._execution_device))[0] + + # get prompt text embeddings + text_input = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt") + prompt_embeds = self.bert(text_input.input_ids.to(self._execution_device))[0] + + # get the initial random noise unless the user supplied it + latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor( + latents_shape, generator=generator, device=self._execution_device, dtype=prompt_embeds.dtype + ) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self._execution_device) + + self.scheduler.set_timesteps(num_inference_steps) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + + extra_kwargs = {} + if accepts_eta: + extra_kwargs["eta"] = eta + + for t in self.progress_bar(self.scheduler.timesteps): + if guidance_scale == 1.0: + # guidance_scale of 1 means no guidance + latents_input = latents + context = prompt_embeds + else: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = torch.cat([latents] * 2) + context = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # predict the noise residual + noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample + # perform guidance + if guidance_scale != 1.0: + noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample + + if XLA_AVAILABLE: + xm.mark_step() + + # scale and decode the image latents with vae + latents = 1 / self.vqvae.config.scaling_factor * latents + image = self.vqvae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + + +################################################################################ +# Code for the text transformer model +################################################################################ +""" PyTorch LDMBERT model.""" + + +logger = logging.get_logger(__name__) + +LDMBERT_PRETRAINED_MODEL_ARCHIVE_list = [ + "ldm-bert", + # See all LDMBert models at https://huggingface.co/models?filter=ldmbert +] + + +LDMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "ldm-bert": "https://huggingface.co/valhalla/ldm-bert/blob/main/config.json", +} + + +""" LDMBERT model configuration""" + + +class LDMBertConfig(PretrainedConfig): + model_type = "ldmbert" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} + + def __init__( + self, + vocab_size=30522, + max_position_embeddings=77, + encoder_layers=32, + encoder_ffn_dim=5120, + encoder_attention_heads=8, + head_dim=64, + encoder_layerdrop=0.0, + activation_function="gelu", + d_model=1280, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + classifier_dropout=0.0, + scale_embedding=False, + use_cache=True, + pad_token_id=0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.head_dim = head_dim + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + + super().__init__(pad_token_id=pad_token_id, **kwargs) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->LDMBert +class LDMBertAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + head_dim: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = False, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = head_dim + self.inner_dim = head_dim * num_heads + + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.out_proj = nn.Linear(self.inner_dim, embed_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: torch.Tensor | None = None, + past_key_value: tuple[torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, + layer_head_mask: torch.Tensor | None = None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class LDMBertEncoderLayer(nn.Module): + def __init__(self, config: LDMBertConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = LDMBertAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + head_dim=config.head_dim, + dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + output_attentions: bool | None = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Args: + hidden_states (`torch.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (`torch.Tensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.Tensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.bart.modeling_bart.BartPretrainedModel with Bart->LDMBert +class LDMBertPreTrainedModel(PreTrainedModel): + config_class = LDMBertConfig + base_model_prefix = "model" + _supports_gradient_checkpointing = True + _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"] + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) + dummy_inputs = { + "attention_mask": input_ids.ne(pad_token), + "input_ids": input_ids, + } + return dummy_inputs + + +class LDMBertEncoder(LDMBertPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`LDMBertEncoderLayer`]. + + Args: + config: LDMBertConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: LDMBertConfig): + super().__init__(config) + + self.dropout = config.dropout + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim) + self.embed_positions = nn.Embedding(config.max_position_embeddings, embed_dim) + self.layers = nn.ModuleList([LDMBertEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layer_norm = nn.LayerNorm(embed_dim) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + head_mask: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + ) -> tuple | BaseModelOutput: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.BaseModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + seq_len = input_shape[1] + if position_ids is None: + position_ids = torch.arange(seq_len, dtype=torch.long, device=inputs_embeds.device).expand((1, -1)) + embed_pos = self.embed_positions(position_ids) + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if torch.is_grad_enabled() and self.gradient_checkpointing: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer, + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class LDMBertModel(LDMBertPreTrainedModel): + _no_split_modules = [] + + def __init__(self, config: LDMBertConfig): + super().__init__(config) + self.model = LDMBertEncoder(config) + self.to_logits = nn.Linear(config.hidden_size, config.vocab_size) + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + return outputs diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py new file mode 100644 index 0000000000000000000000000000000000000000..18cb8274f9b5fefd481346c25297d069e60984aa --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py @@ -0,0 +1,196 @@ +import inspect + +import numpy as np +import PIL.Image +import torch + +from ...models import UNet2DModel, VQModel +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import PIL_INTERPOLATION, is_torch_xla_available +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +def preprocess(image): + w, h = image.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +class LDMSuperResolutionPipeline(DiffusionPipeline): + r""" + A pipeline for image super-resolution using latent diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) model to encode and decode images to and from latent representations. + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], + [`EulerAncestralDiscreteScheduler`], [`DPMSolverMultistepScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vqvae: VQModel, + unet: UNet2DModel, + scheduler: DDIMScheduler + | PNDMScheduler + | LMSDiscreteScheduler + | EulerDiscreteScheduler + | EulerAncestralDiscreteScheduler + | DPMSolverMultistepScheduler, + ): + super().__init__() + self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + image: torch.Tensor | PIL.Image.Image = None, + batch_size: int | None = 1, + num_inference_steps: int | None = 100, + eta: float | None = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + ) -> tuple | ImagePipelineOutput: + r""" + The call function to the pipeline for generation. + + Args: + image (`torch.Tensor` or `PIL.Image.Image`): + `Image` or tensor representing an image batch to be used as the starting point for the process. + batch_size (`int`, *optional*, defaults to 1): + Number of images to generate. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. + + Example: + + ```py + >>> import requests + >>> from PIL import Image + >>> from io import BytesIO + >>> from diffusers import LDMSuperResolutionPipeline + >>> import torch + + >>> # load model and scheduler + >>> pipeline = LDMSuperResolutionPipeline.from_pretrained("CompVis/ldm-super-resolution-4x-openimages") + >>> pipeline = pipeline.to("cuda") + + >>> # let's download an image + >>> url = ( + ... "https://user-images.githubusercontent.com/38061659/199705896-b48e17b8-b231-47cd-a270-4ffa5a93fa3e.png" + ... ) + >>> response = requests.get(url) + >>> low_res_img = Image.open(BytesIO(response.content)).convert("RGB") + >>> low_res_img = low_res_img.resize((128, 128)) + + >>> # run pipeline in inference (sample random noise and denoise) + >>> upscaled_image = pipeline(low_res_img, num_inference_steps=100, eta=1).images[0] + >>> # save image + >>> upscaled_image.save("ldm_generated_image.png") + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, torch.Tensor): + batch_size = image.shape[0] + else: + raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}") + + if isinstance(image, PIL.Image.Image): + image = preprocess(image) + + height, width = image.shape[-2:] + + # in_channels should be 6: 3 for latents, 3 for low resolution image + latents_shape = (batch_size, self.unet.config.in_channels // 2, height, width) + latents_dtype = next(self.unet.parameters()).dtype + + latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) + + image = image.to(device=self.device, dtype=latents_dtype) + + # set timesteps and move to the correct device + self.scheduler.set_timesteps(num_inference_steps, device=self.device) + timesteps_tensor = self.scheduler.timesteps + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature. + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_kwargs = {} + if accepts_eta: + extra_kwargs["eta"] = eta + + for t in self.progress_bar(timesteps_tensor): + # concat latents and low resolution image in the channel dimension. + latents_input = torch.cat([latents, image], dim=1) + latents_input = self.scheduler.scale_model_input(latents_input, t) + # predict the noise residual + noise_pred = self.unet(latents_input, t).sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample + + if XLA_AVAILABLE: + xm.mark_step() + + # decode the image latents with the VQVAE + image = self.vqvae.decode(latents).sample + image = torch.clamp(image, -1.0, 1.0) + image = image / 2 + 0.5 + image = image.cpu().permute(0, 2, 3, 1).numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/latte/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/latte/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4296b42e125303c0e036b9f2deecc36bdf959de3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/latte/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_latte"] = ["LattePipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_latte import LattePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/latte/pipeline_latte.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/latte/pipeline_latte.py new file mode 100644 index 0000000000000000000000000000000000000000..eed7762cebf15a9dec28e709b27399660ff7d517 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/latte/pipeline_latte.py @@ -0,0 +1,914 @@ +# Copyright 2025 the Latte Team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +from dataclasses import dataclass +from typing import Callable + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models import AutoencoderKL, LatteTransformer3DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + BACKENDS_MAPPING, + BaseOutput, + deprecate, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...video_processor import VideoProcessor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LattePipeline + >>> from diffusers.utils import export_to_gif + + >>> # You can replace the checkpoint id with "maxin-cn/Latte-1" too. + >>> pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16) + >>> # Enable memory optimizations. + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A small cactus with a happy face in the Sahara desert." + >>> videos = pipe(prompt).frames[0] + >>> export_to_gif(videos, "latte.gif") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class LattePipelineOutput(BaseOutput): + frames: torch.Tensor + + +class LattePipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Latte. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. Latte uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`LatteTransformer3DModel`]): + A text conditioned `LatteTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + bad_punct_regex = re.compile(r"[#®•©™&@·º½¾¿¡§~\)\(\]\[\}\{\|\\/\\*]{1,}") + + _optional_components = ["tokenizer", "text_encoder"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: LatteTransformer3DModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) + + # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py + def mask_text_embeddings(self, emb, mask): + if emb.shape[0] == 1: + keep_index = mask.sum().item() + return emb[:, :, :keep_index, :], keep_index # 1, 120, 4096 -> 1 7 4096 + else: + masked_feature = emb * mask[:, None, :, None] # 1 120 4096 + return masked_feature, emb.shape[2] + + # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + clean_caption: bool = False, + mask_feature: bool = True, + dtype=None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + Latte, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of video that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For Latte, it's should be the embeddings of the "" string. + clean_caption (bool, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + mask_feature: (bool, defaults to `True`): + If `True`, the function will mask the text embeddings. + """ + embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + max_length = 120 + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds_attention_mask = attention_mask + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds_attention_mask = torch.ones_like(prompt_embeds) + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1) + prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + + # Perform additional masking. + if mask_feature and not embeds_initially_provided: + prompt_embeds = prompt_embeds.unsqueeze(1) + masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask) + masked_prompt_embeds = masked_prompt_embeds.squeeze(1) + masked_negative_prompt_embeds = ( + negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None + ) + + return masked_prompt_embeds, masked_negative_prompt_embeds + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str = "", + num_inference_steps: int = 50, + timesteps: list[int] | None = None, + guidance_scale: float = 7.5, + num_images_per_prompt: int = 1, + video_length: int = 16, + height: int = 512, + width: int = 512, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + clean_caption: bool = True, + mask_feature: bool = True, + enable_temporal_attentions: bool = True, + decode_chunk_size: int = 14, + ) -> LattePipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to + the text `prompt`, usually at the expense of lower video quality. + video_length (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated video. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For Latte this negative prompt should be "". If not provided, + negative_prompt_embeds will be generated from `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate video. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable[[int, int], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A callback function or a list of callback functions to be called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`list[str]`, *optional*): + A list of tensor inputs that should be passed to the callback function. If not defined, all tensor + inputs will be passed. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked. + enable_temporal_attentions (`bool`, *optional*, defaults to `True`): Whether to enable temporal attentions + decode_chunk_size (`int`, *optional*): + The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the + expense of more memory usage. By default, the decoder decodes all frames at once for maximal quality. + For lower memory usage, reduce `decode_chunk_size`. + + Examples: + + Returns: + [`~pipelines.latte.pipeline_latte.LattePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.latte.pipeline_latte.LattePipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images + """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default + decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else video_length + + # 1. Check inputs. Raise error if not correct + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._current_timestep = None + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clean_caption=clean_caption, + mask_feature=mask_feature, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + video_length, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + is_npu = latent_model_input.device.type == "npu" + if isinstance(current_timestep, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=current_timestep, + enable_temporal_attentions=enable_temporal_attentions, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # use learned sigma? + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred = noise_pred.chunk(2, dim=1)[0] + + # compute previous video: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latents": + deprecation_message = ( + "Passing `output_type='latents'` is deprecated. Please pass `output_type='latent'` instead." + ) + deprecate("output_type_latents", "1.0.0", deprecation_message, standard_warn=False) + output_type = "latent" + + if not output_type == "latent": + video = self.decode_latents(latents, video_length, decode_chunk_size=decode_chunk_size) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return LattePipelineOutput(frames=video) + + # Similar to diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion.decode_latents + def decode_latents(self, latents: torch.Tensor, video_length: int, decode_chunk_size: int = 14): + # [batch, channels, frames, height, width] -> [batch*frames, channels, height, width] + latents = latents.permute(0, 2, 1, 3, 4).flatten(0, 1) + + latents = 1 / self.vae.config.scaling_factor * latents + + forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward + accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys()) + + # decode decode_chunk_size frames at a time to avoid OOM + frames = [] + for i in range(0, latents.shape[0], decode_chunk_size): + num_frames_in = latents[i : i + decode_chunk_size].shape[0] + decode_kwargs = {} + if accepts_num_frames: + # we only pass num_frames_in if it's expected + decode_kwargs["num_frames"] = num_frames_in + + frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample + frames.append(frame) + frames = torch.cat(frames, dim=0) + + # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width] + frames = frames.reshape(-1, video_length, *frames.shape[1:]).permute(0, 2, 1, 3, 4) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.float() + return frames diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ledits_pp/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ledits_pp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aae3b1cb18ce96c7d79d3c44bc37cf2ae4db9720 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ledits_pp/__init__.py @@ -0,0 +1,55 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_leditspp_stable_diffusion"] = ["LEditsPPPipelineStableDiffusion"] + _import_structure["pipeline_leditspp_stable_diffusion_xl"] = ["LEditsPPPipelineStableDiffusionXL"] + + _import_structure["pipeline_output"] = ["LEditsPPDiffusionPipelineOutput", "LEditsPPDiffusionPipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_leditspp_stable_diffusion import ( + LEditsPPDiffusionPipelineOutput, + LEditsPPInversionPipelineOutput, + LEditsPPPipelineStableDiffusion, + ) + from .pipeline_leditspp_stable_diffusion_xl import LEditsPPPipelineStableDiffusionXL + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..864f9feeb5aa03c8cadf6e8829ca02cbe81f9ecf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py @@ -0,0 +1,1589 @@ +import inspect +import math +from itertools import repeat +from typing import Any, Callable + +import torch +import torch.nn.functional as F +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention_processor import Attention, AttnProcessor +from ...models.lora import adjust_lora_scale_text_encoder +from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from ...schedulers import DDIMScheduler, DPMSolverMultistepScheduler +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import LEditsPPDiffusionPipelineOutput, LEditsPPInversionPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + + >>> from diffusers import LEditsPPPipelineStableDiffusion + >>> from diffusers.utils import load_image + + >>> pipe = LEditsPPPipelineStableDiffusion.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe.enable_vae_tiling() + >>> pipe = pipe.to("cuda") + + >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/cherry_blossom.png" + >>> image = load_image(img_url).resize((512, 512)) + + >>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.1) + + >>> edited_image = pipe( + ... editing_prompt=["cherry blossom"], edit_guidance_scale=10.0, edit_threshold=0.75 + ... ).images[0] + ``` +""" + + +# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionAttendAndExcitePipeline.AttentionStore +class LeditsAttentionStore: + @staticmethod + def get_empty_store(): + return {"down_cross": [], "mid_cross": [], "up_cross": [], "down_self": [], "mid_self": [], "up_self": []} + + def __call__(self, attn, is_cross: bool, place_in_unet: str, editing_prompts, PnP=False): + # attn.shape = batch_size * head_size, seq_len query, seq_len_key + if attn.shape[1] <= self.max_size: + bs = 1 + int(PnP) + editing_prompts + skip = 2 if PnP else 1 # skip PnP & unconditional + attn = torch.stack(attn.split(self.batch_size)).permute(1, 0, 2, 3) + source_batch_size = int(attn.shape[1] // bs) + self.forward(attn[:, skip * source_batch_size :], is_cross, place_in_unet) + + def forward(self, attn, is_cross: bool, place_in_unet: str): + key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" + + self.step_store[key].append(attn) + + def between_steps(self, store_step=True): + if store_step: + if self.average: + if len(self.attention_store) == 0: + self.attention_store = self.step_store + else: + for key in self.attention_store: + for i in range(len(self.attention_store[key])): + self.attention_store[key][i] += self.step_store[key][i] + else: + if len(self.attention_store) == 0: + self.attention_store = [self.step_store] + else: + self.attention_store.append(self.step_store) + + self.cur_step += 1 + self.step_store = self.get_empty_store() + + def get_attention(self, step: int): + if self.average: + attention = { + key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store + } + else: + assert step is not None + attention = self.attention_store[step] + return attention + + def aggregate_attention( + self, attention_maps, prompts, res: int | tuple[int], from_where: list[str], is_cross: bool, select: int + ): + out = [[] for x in range(self.batch_size)] + if isinstance(res, int): + num_pixels = res**2 + resolution = (res, res) + else: + num_pixels = res[0] * res[1] + resolution = res[:2] + + for location in from_where: + for bs_item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: + for batch, item in enumerate(bs_item): + if item.shape[1] == num_pixels: + cross_maps = item.reshape(len(prompts), -1, *resolution, item.shape[-1])[select] + out[batch].append(cross_maps) + + out = torch.stack([torch.cat(x, dim=0) for x in out]) + # average over heads + out = out.sum(1) / out.shape[1] + return out + + def __init__(self, average: bool, batch_size=1, max_resolution=16, max_size: int = None): + self.step_store = self.get_empty_store() + self.attention_store = [] + self.cur_step = 0 + self.average = average + self.batch_size = batch_size + if max_size is None: + self.max_size = max_resolution**2 + elif max_size is not None and max_resolution is None: + self.max_size = max_size + else: + raise ValueError("Only allowed to set one of max_resolution or max_size") + + +# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionAttendAndExcitePipeline.GaussianSmoothing +class LeditsGaussianSmoothing: + def __init__(self, device): + kernel_size = [3, 3] + sigma = [0.5, 0.5] + + # The gaussian kernel is the product of the gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij") + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2)) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(1, *[1] * (kernel.dim() - 1)) + + self.weight = kernel.to(device) + + def __call__(self, input): + """ + Arguments: + Apply gaussian filter to input. + input (torch.Tensor): Input to apply gaussian filter on. + Returns: + filtered (torch.Tensor): Filtered output. + """ + return F.conv2d(input, weight=self.weight.to(input.dtype)) + + +class LEDITSCrossAttnProcessor: + def __init__(self, attention_store, place_in_unet, pnp, editing_prompts): + self.attnstore = attention_store + self.place_in_unet = place_in_unet + self.editing_prompts = editing_prompts + self.pnp = pnp + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states, + attention_mask=None, + temb=None, + ): + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + self.attnstore( + attention_probs, + is_cross=True, + place_in_unet=self.place_in_unet, + editing_prompts=self.editing_prompts, + PnP=self.pnp, + ) + + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class LEditsPPPipelineStableDiffusion( + DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin +): + """ + Pipeline for textual image editing using LEDits++ with Stable Diffusion. + + This model inherits from [`DiffusionPipeline`] and builds on the [`StableDiffusionPipeline`]. Check the superclass + documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular + device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer ([`~transformers.CLIPTokenizer`]): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`DPMSolverMultistepScheduler`] or [`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DPMSolverMultistepScheduler`] or [`DDIMScheduler`]. If any other scheduler is passed it will + automatically be set to [`DPMSolverMultistepScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler | DPMSolverMultistepScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if not isinstance(scheduler, DDIMScheduler) and not isinstance(scheduler, DPMSolverMultistepScheduler): + scheduler = DPMSolverMultistepScheduler.from_config( + scheduler.config, algorithm_type="sde-dpmsolver++", solver_order=2 + ) + logger.warning( + "This pipeline only supports DDIMScheduler and DPMSolverMultistepScheduler. " + "The scheduler has been changed to DPMSolverMultistepScheduler." + ) + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + self.inversion_steps = None + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, eta, generator=None): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + negative_prompt=None, + editing_prompt_embeddings=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if editing_prompt_embeddings is not None and negative_prompt_embeds is not None: + if editing_prompt_embeddings.shape != negative_prompt_embeds.shape: + raise ValueError( + "`editing_prompt_embeddings` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `editing_prompt_embeddings` {editing_prompt_embeddings.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, latents): + # shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + + # if latents.shape != shape: + # raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_unet(self, attention_store, PnP: bool = False): + attn_procs = {} + for name in self.unet.attn_processors.keys(): + if name.startswith("mid_block"): + place_in_unet = "mid" + elif name.startswith("up_blocks"): + place_in_unet = "up" + elif name.startswith("down_blocks"): + place_in_unet = "down" + else: + continue + + if "attn2" in name and place_in_unet != "mid": + attn_procs[name] = LEDITSCrossAttnProcessor( + attention_store=attention_store, + place_in_unet=place_in_unet, + pnp=PnP, + editing_prompts=self.enabled_editing_prompts, + ) + else: + attn_procs[name] = AttnProcessor() + + self.unet.set_attn_processor(attn_procs) + + def encode_prompt( + self, + device, + num_images_per_prompt, + enable_edit_guidance, + negative_prompt=None, + editing_prompt=None, + negative_prompt_embeds: torch.Tensor | None = None, + editing_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + enable_edit_guidance (`bool`): + whether to perform any editing or reconstruct the input image instead + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + editing_prompt (`str` or `list[str]`, *optional*): + Editing prompt(s) to be encoded. If not defined, one has to pass `editing_prompt_embeds` instead. + editing_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + batch_size = self.batch_size + num_edit_tokens = None + + if negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but exoected" + f"{batch_size} based on the input images. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = negative_prompt_embeds.dtype + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + if enable_edit_guidance: + if editing_prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + # if isinstance(self, TextualInversionLoaderMixin): + # prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + if isinstance(editing_prompt, str): + editing_prompt = [editing_prompt] + + max_length = negative_prompt_embeds.shape[1] + text_inputs = self.tokenizer( + [x for item in editing_prompt for x in repeat(item, batch_size)], + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + return_length=True, + ) + + num_edit_tokens = text_inputs.length - 2 # not counting startoftext and endoftext + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer( + [x for item in editing_prompt for x in repeat(item, batch_size)], + padding="longest", + return_tensors="pt", + ).input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + editing_prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + editing_prompt_embeds = editing_prompt_embeds[0] + else: + editing_prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + editing_prompt_embeds = editing_prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + editing_prompt_embeds = self.text_encoder.text_model.final_layer_norm(editing_prompt_embeds) + + editing_prompt_embeds = editing_prompt_embeds.to(dtype=negative_prompt_embeds.dtype, device=device) + + bs_embed_edit, seq_len, _ = editing_prompt_embeds.shape + editing_prompt_embeds = editing_prompt_embeds.to(dtype=negative_prompt_embeds.dtype, device=device) + editing_prompt_embeds = editing_prompt_embeds.repeat(1, num_images_per_prompt, 1) + editing_prompt_embeds = editing_prompt_embeds.view(bs_embed_edit * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return editing_prompt_embeds, negative_prompt_embeds, num_edit_tokens + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + negative_prompt: str | list[str] | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + editing_prompt: str | list[str] | None = None, + editing_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + reverse_editing_direction: bool | list[bool] | None = False, + edit_guidance_scale: float | list[float] | None = 5, + edit_warmup_steps: int | list[int] | None = 0, + edit_cooldown_steps: int | list[int] | None = None, + edit_threshold: float | list[float] | None = 0.9, + user_mask: torch.Tensor | None = None, + sem_guidance: list[torch.Tensor] | None = None, + use_cross_attn_mask: bool = False, + use_intersect_mask: bool = True, + attn_store_steps: list[int] | None = [], + store_averaged_over_steps: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + guidance_rescale: float = 0.0, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for editing. The + [`~pipelines.ledits_pp.LEditsPPPipelineStableDiffusion.invert`] method has to be called beforehand. Edits will + always be performed for the last inverted image(s). + + Args: + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] instead of a plain + tuple. + editing_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. The image is reconstructed by setting + `editing_prompt = None`. Guidance direction of prompt should be specified via + `reverse_editing_direction`. + editing_prompt_embeds (`torch.Tensor>`, *optional*): + Pre-computed embeddings to use for guiding the image generation. Guidance direction of embedding should + be specified via `reverse_editing_direction`. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + reverse_editing_direction (`bool` or `list[bool]`, *optional*, defaults to `False`): + Whether the corresponding prompt in `editing_prompt` should be increased or decreased. + edit_guidance_scale (`float` or `list[float]`, *optional*, defaults to 5): + Guidance scale for guiding the image generation. If provided as list values should correspond to + `editing_prompt`. `edit_guidance_scale` is defined as `s_e` of equation 12 of [LEDITS++ + Paper](https://huggingface.co/papers/2301.12247). + edit_warmup_steps (`float` or `list[float]`, *optional*, defaults to 10): + Number of diffusion steps (for each prompt) for which guidance will not be applied. + edit_cooldown_steps (`float` or `list[float]`, *optional*, defaults to `None`): + Number of diffusion steps (for each prompt) after which guidance will no longer be applied. + edit_threshold (`float` or `list[float]`, *optional*, defaults to 0.9): + Masking threshold of guidance. Threshold should be proportional to the image region that is modified. + 'edit_threshold' is defined as 'λ' of equation 12 of [LEDITS++ + Paper](https://huggingface.co/papers/2301.12247). + user_mask (`torch.Tensor`, *optional*): + User-provided mask for even better control over the editing process. This is helpful when LEDITS++'s + implicit masks do not meet user preferences. + sem_guidance (`list[torch.Tensor]`, *optional*): + list of pre-generated guidance vectors to be applied at generation. Length of the list has to + correspond to `num_inference_steps`. + use_cross_attn_mask (`bool`, defaults to `False`): + Whether cross-attention masks are used. Cross-attention masks are always used when use_intersect_mask + is set to true. Cross-attention masks are defined as 'M^1' of equation 12 of [LEDITS++ + paper](https://huggingface.co/papers/2311.16711). + use_intersect_mask (`bool`, defaults to `True`): + Whether the masking term is calculated as intersection of cross-attention masks and masks derived from + the noise estimate. Cross-attention mask are defined as 'M^1' and masks derived from the noise estimate + are defined as 'M^2' of equation 12 of [LEDITS++ paper](https://huggingface.co/papers/2311.16711). + attn_store_steps (`list[int]`, *optional*): + Steps for which the attention maps are stored in the AttentionStore. Just for visualization purposes. + store_averaged_over_steps (`bool`, defaults to `True`): + Whether the attention maps for the 'attn_store_steps' are stored averaged over the diffusion steps. If + False, attention maps for each step are stores separately. Just for visualization purposes. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) + content, according to the `safety_checker`. + """ + + if self.inversion_steps is None: + raise ValueError( + "You need to invert an input image first before calling the pipeline. The `invert` method has to be called beforehand. Edits will always be performed for the last inverted image(s)." + ) + + eta = self.eta + num_images_per_prompt = 1 + latents = self.init_latents + + zs = self.zs + self.scheduler.set_timesteps(len(self.scheduler.timesteps)) + + if use_intersect_mask: + use_cross_attn_mask = True + + if use_cross_attn_mask: + self.smoothing = LeditsGaussianSmoothing(self.device) + + if user_mask is not None: + user_mask = user_mask.to(self.device) + + org_prompt = "" + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + negative_prompt, + editing_prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + batch_size = self.batch_size + + if editing_prompt: + enable_edit_guidance = True + if isinstance(editing_prompt, str): + editing_prompt = [editing_prompt] + self.enabled_editing_prompts = len(editing_prompt) + elif editing_prompt_embeds is not None: + enable_edit_guidance = True + self.enabled_editing_prompts = editing_prompt_embeds.shape[0] + else: + self.enabled_editing_prompts = 0 + enable_edit_guidance = False + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + edit_concepts, uncond_embeddings, num_edit_tokens = self.encode_prompt( + editing_prompt=editing_prompt, + device=self.device, + num_images_per_prompt=num_images_per_prompt, + enable_edit_guidance=enable_edit_guidance, + negative_prompt=negative_prompt, + editing_prompt_embeds=editing_prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if enable_edit_guidance: + text_embeddings = torch.cat([uncond_embeddings, edit_concepts]) + self.text_cross_attention_maps = [editing_prompt] if isinstance(editing_prompt, str) else editing_prompt + else: + text_embeddings = torch.cat([uncond_embeddings]) + + # 4. Prepare timesteps + # self.scheduler.set_timesteps(num_inference_steps, device=self.device) + timesteps = self.inversion_steps + t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0] :])} + + if use_cross_attn_mask: + self.attention_store = LeditsAttentionStore( + average=store_averaged_over_steps, + batch_size=batch_size, + max_size=(latents.shape[-2] / 4.0) * (latents.shape[-1] / 4.0), + max_resolution=None, + ) + self.prepare_unet(self.attention_store, PnP=False) + resolution = latents.shape[-2:] + att_res = (int(resolution[0] / 4), int(resolution[1] / 4)) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + None, + None, + text_embeddings.dtype, + self.device, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(eta) + + self.sem_guidance = None + self.activation_mask = None + + # 7. Denoising loop + num_warmup_steps = 0 + with self.progress_bar(total=len(timesteps)) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + + if enable_edit_guidance: + latent_model_input = torch.cat([latents] * (1 + self.enabled_editing_prompts)) + else: + latent_model_input = latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + text_embed_input = text_embeddings + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embed_input).sample + + noise_pred_out = noise_pred.chunk(1 + self.enabled_editing_prompts) # [b,4, 64, 64] + noise_pred_uncond = noise_pred_out[0] + noise_pred_edit_concepts = noise_pred_out[1:] + + noise_guidance_edit = torch.zeros( + noise_pred_uncond.shape, + device=self.device, + dtype=noise_pred_uncond.dtype, + ) + + if sem_guidance is not None and len(sem_guidance) > i: + noise_guidance_edit += sem_guidance[i].to(self.device) + + elif enable_edit_guidance: + if self.activation_mask is None: + self.activation_mask = torch.zeros( + (len(timesteps), len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape) + ) + + if self.sem_guidance is None: + self.sem_guidance = torch.zeros((len(timesteps), *noise_pred_uncond.shape)) + + for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts): + if isinstance(edit_warmup_steps, list): + edit_warmup_steps_c = edit_warmup_steps[c] + else: + edit_warmup_steps_c = edit_warmup_steps + if i < edit_warmup_steps_c: + continue + + if isinstance(edit_guidance_scale, list): + edit_guidance_scale_c = edit_guidance_scale[c] + else: + edit_guidance_scale_c = edit_guidance_scale + + if isinstance(edit_threshold, list): + edit_threshold_c = edit_threshold[c] + else: + edit_threshold_c = edit_threshold + if isinstance(reverse_editing_direction, list): + reverse_editing_direction_c = reverse_editing_direction[c] + else: + reverse_editing_direction_c = reverse_editing_direction + + if isinstance(edit_cooldown_steps, list): + edit_cooldown_steps_c = edit_cooldown_steps[c] + elif edit_cooldown_steps is None: + edit_cooldown_steps_c = i + 1 + else: + edit_cooldown_steps_c = edit_cooldown_steps + + if i >= edit_cooldown_steps_c: + continue + + noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond + + if reverse_editing_direction_c: + noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1 + + noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c + + if user_mask is not None: + noise_guidance_edit_tmp = noise_guidance_edit_tmp * user_mask + + if use_cross_attn_mask: + out = self.attention_store.aggregate_attention( + attention_maps=self.attention_store.step_store, + prompts=self.text_cross_attention_maps, + res=att_res, + from_where=["up", "down"], + is_cross=True, + select=self.text_cross_attention_maps.index(editing_prompt[c]), + ) + attn_map = out[:, :, :, 1 : 1 + num_edit_tokens[c]] # 0 -> startoftext + + # average over all tokens + if attn_map.shape[3] != num_edit_tokens[c]: + raise ValueError( + f"Incorrect shape of attention_map. Expected size {num_edit_tokens[c]}, but found {attn_map.shape[3]}!" + ) + + attn_map = torch.sum(attn_map, dim=3) + + # gaussian_smoothing + attn_map = F.pad(attn_map.unsqueeze(1), (1, 1, 1, 1), mode="reflect") + attn_map = self.smoothing(attn_map).squeeze(1) + + # torch.quantile function expects float32 + if attn_map.dtype == torch.float32: + tmp = torch.quantile(attn_map.flatten(start_dim=1), edit_threshold_c, dim=1) + else: + tmp = torch.quantile( + attn_map.flatten(start_dim=1).to(torch.float32), edit_threshold_c, dim=1 + ).to(attn_map.dtype) + attn_mask = torch.where( + attn_map >= tmp.unsqueeze(1).unsqueeze(1).repeat(1, *att_res), 1.0, 0.0 + ) + + # resolution must match latent space dimension + attn_mask = F.interpolate( + attn_mask.unsqueeze(1), + noise_guidance_edit_tmp.shape[-2:], # 64,64 + ).repeat(1, 4, 1, 1) + self.activation_mask[i, c] = attn_mask.detach().cpu() + if not use_intersect_mask: + noise_guidance_edit_tmp = noise_guidance_edit_tmp * attn_mask + + if use_intersect_mask: + if t <= 800: + noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp) + noise_guidance_edit_tmp_quantile = torch.sum( + noise_guidance_edit_tmp_quantile, dim=1, keepdim=True + ) + noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat( + 1, self.unet.config.in_channels, 1, 1 + ) + + # torch.quantile function expects float32 + if noise_guidance_edit_tmp_quantile.dtype == torch.float32: + tmp = torch.quantile( + noise_guidance_edit_tmp_quantile.flatten(start_dim=2), + edit_threshold_c, + dim=2, + keepdim=False, + ) + else: + tmp = torch.quantile( + noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32), + edit_threshold_c, + dim=2, + keepdim=False, + ).to(noise_guidance_edit_tmp_quantile.dtype) + + intersect_mask = ( + torch.where( + noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None], + torch.ones_like(noise_guidance_edit_tmp), + torch.zeros_like(noise_guidance_edit_tmp), + ) + * attn_mask + ) + + self.activation_mask[i, c] = intersect_mask.detach().cpu() + + noise_guidance_edit_tmp = noise_guidance_edit_tmp * intersect_mask + + else: + # print(f"only attention mask for step {i}") + noise_guidance_edit_tmp = noise_guidance_edit_tmp * attn_mask + + elif not use_cross_attn_mask: + # calculate quantile + noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp) + noise_guidance_edit_tmp_quantile = torch.sum( + noise_guidance_edit_tmp_quantile, dim=1, keepdim=True + ) + noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1, 4, 1, 1) + + # torch.quantile function expects float32 + if noise_guidance_edit_tmp_quantile.dtype == torch.float32: + tmp = torch.quantile( + noise_guidance_edit_tmp_quantile.flatten(start_dim=2), + edit_threshold_c, + dim=2, + keepdim=False, + ) + else: + tmp = torch.quantile( + noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32), + edit_threshold_c, + dim=2, + keepdim=False, + ).to(noise_guidance_edit_tmp_quantile.dtype) + + self.activation_mask[i, c] = ( + torch.where( + noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None], + torch.ones_like(noise_guidance_edit_tmp), + torch.zeros_like(noise_guidance_edit_tmp), + ) + .detach() + .cpu() + ) + + noise_guidance_edit_tmp = torch.where( + noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None], + noise_guidance_edit_tmp, + torch.zeros_like(noise_guidance_edit_tmp), + ) + + noise_guidance_edit += noise_guidance_edit_tmp + + self.sem_guidance[i] = noise_guidance_edit.detach().cpu() + + noise_pred = noise_pred_uncond + noise_guidance_edit + + if enable_edit_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_edit_concepts.mean(dim=0, keepdim=False), + guidance_rescale=self.guidance_rescale, + ) + + idx = t_to_idx[int(t)] + latents = self.scheduler.step( + noise_pred, t, latents, variance_noise=zs[idx], **extra_step_kwargs + ).prev_sample + + # step callback + if use_cross_attn_mask: + store_step = i in attn_store_steps + self.attention_store.between_steps(store_step) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + # prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # 8. Post-processing + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return LEditsPPDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + @torch.no_grad() + def invert( + self, + image: PipelineImageInput, + source_prompt: str = "", + source_guidance_scale: float = 3.5, + num_inversion_steps: int = 30, + skip: float = 0.15, + generator: torch.Generator | None = None, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + height: int | None = None, + width: int | None = None, + resize_mode: str | None = "default", + crops_coords: tuple[int, int, int, int] | None = None, + ): + r""" + The function to the pipeline for image inversion as described by the [LEDITS++ + Paper](https://huggingface.co/papers/2301.12247). If the scheduler is set to [`~schedulers.DDIMScheduler`] the + inversion proposed by [edit-friendly DPDM](https://huggingface.co/papers/2304.06140) will be performed instead. + + Args: + image (`PipelineImageInput`): + Input for the image(s) that are to be edited. Multiple input images have to default to the same aspect + ratio. + source_prompt (`str`, defaults to `""`): + Prompt describing the input image that will be used for guidance during inversion. Guidance is disabled + if the `source_prompt` is `""`. + source_guidance_scale (`float`, defaults to `3.5`): + Strength of guidance during inversion. + num_inversion_steps (`int`, defaults to `30`): + Number of total performed inversion steps after discarding the initial `skip` steps. + skip (`float`, defaults to `0.15`): + Portion of initial steps that will be ignored for inversion and subsequent generation. Lower values + will lead to stronger changes to the input image. `skip` has to be between `0` and `1`. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make inversion + deterministic. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + height (`int`, *optional*, defaults to `None`): + The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default + height. + width (`int`, *optional*`, defaults to `None`): + The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width. + resize_mode (`str`, *optional*, defaults to `default`): + The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within + the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will + resize the image to fit within the specified width and height, maintaining the aspect ratio, and then + center the image within the dimensions, filling empty with data from image. If `crop`, will resize the + image to fit within the specified width and height, maintaining the aspect ratio, and then center the + image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only + supported for PIL image input. + crops_coords (`list[tuple[int, int, int, int]]`, *optional*, defaults to `None`): + The crop coordinates for each image in the batch. If `None`, will not crop the image. + + Returns: + [`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s) + and respective VAE reconstruction(s). + """ + if height is not None and height % 32 != 0 or width is not None and width % 32 != 0: + raise ValueError("height and width must be a factor of 32.") + # Reset attn processor, we do not want to store attn maps during inversion + self.unet.set_attn_processor(AttnProcessor()) + + self.eta = 1.0 + + self.scheduler.config.timestep_spacing = "leading" + self.scheduler.set_timesteps(int(num_inversion_steps * (1 + skip))) + self.inversion_steps = self.scheduler.timesteps[-num_inversion_steps:] + timesteps = self.inversion_steps + + # 1. encode image + x0, resized = self.encode_image( + image, + dtype=self.text_encoder.dtype, + height=height, + width=width, + resize_mode=resize_mode, + crops_coords=crops_coords, + ) + self.batch_size = x0.shape[0] + + # autoencoder reconstruction + image_rec = self.vae.decode(x0 / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0] + image_rec = self.image_processor.postprocess(image_rec, output_type="pil") + + # 2. get embeddings + do_classifier_free_guidance = source_guidance_scale > 1.0 + + lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + + uncond_embedding, text_embeddings, _ = self.encode_prompt( + num_images_per_prompt=1, + device=self.device, + negative_prompt=None, + enable_edit_guidance=do_classifier_free_guidance, + editing_prompt=source_prompt, + lora_scale=lora_scale, + clip_skip=clip_skip, + ) + + # 3. find zs and xts + variance_noise_shape = (num_inversion_steps, *x0.shape) + + # intermediate latents + t_to_idx = {int(v): k for k, v in enumerate(timesteps)} + xts = torch.zeros(size=variance_noise_shape, device=self.device, dtype=uncond_embedding.dtype) + + for t in reversed(timesteps): + idx = num_inversion_steps - t_to_idx[int(t)] - 1 + noise = randn_tensor(shape=x0.shape, generator=generator, device=self.device, dtype=x0.dtype) + xts[idx] = self.scheduler.add_noise(x0, noise, torch.Tensor([t])) + xts = torch.cat([x0.unsqueeze(0), xts], dim=0) + + self.scheduler.set_timesteps(len(self.scheduler.timesteps)) + # noise maps + zs = torch.zeros(size=variance_noise_shape, device=self.device, dtype=uncond_embedding.dtype) + + with self.progress_bar(total=len(timesteps)) as progress_bar: + for t in timesteps: + idx = num_inversion_steps - t_to_idx[int(t)] - 1 + # 1. predict noise residual + xt = xts[idx + 1] + + noise_pred = self.unet(xt, timestep=t, encoder_hidden_states=uncond_embedding).sample + + if not source_prompt == "": + noise_pred_cond = self.unet(xt, timestep=t, encoder_hidden_states=text_embeddings).sample + noise_pred = noise_pred + source_guidance_scale * (noise_pred_cond - noise_pred) + + xtm1 = xts[idx] + z, xtm1_corrected = compute_noise(self.scheduler, xtm1, xt, t, noise_pred, self.eta) + zs[idx] = z + + # correction to avoid error accumulation + xts[idx] = xtm1_corrected + + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self.init_latents = xts[-1].expand(self.batch_size, -1, -1, -1) + zs = zs.flip(0) + self.zs = zs + + return LEditsPPInversionPipelineOutput(images=resized, vae_reconstruction_images=image_rec) + + @torch.no_grad() + def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="default", crops_coords=None): + image = self.image_processor.preprocess( + image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + height, width = image.shape[-2:] + if height % 32 != 0 or width % 32 != 0: + raise ValueError( + "Image height and width must be a factor of 32. " + "Consider down-sampling the input using the `height` and `width` parameters" + ) + resized = self.image_processor.postprocess(image=image, output_type="pil") + + if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5: + logger.warning( + "Your input images far exceed the default resolution of the underlying diffusion model. " + "The output images may contain severe artifacts! " + "Consider down-sampling the input using the `height` and `width` parameters" + ) + image = image.to(dtype) + + x0 = self.vae.encode(image.to(self.device)).latent_dist.mode() + x0 = x0.to(dtype) + x0 = self.vae.config.scaling_factor * x0 + return x0, resized + + +def compute_noise_ddim(scheduler, prev_latents, latents, timestep, noise_pred, eta): + # 1. get previous step value (=t-1) + prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod + ) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://huggingface.co/papers/2010.02502 + pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) + + # 4. Clip "predicted x_0" + if scheduler.config.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = scheduler._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://huggingface.co/papers/2010.02502 + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred + + # modified so that updated xtm1 is returned as well (to avoid error accumulation) + mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + if variance > 0.0: + noise = (prev_latents - mu_xt) / (variance ** (0.5) * eta) + else: + noise = torch.tensor([0.0]).to(latents.device) + + return noise, mu_xt + (eta * variance**0.5) * noise + + +def compute_noise_sde_dpm_pp_2nd(scheduler, prev_latents, latents, timestep, noise_pred, eta): + def first_order_update(model_output, sample): # timestep, prev_timestep, sample): + sigma_t, sigma_s = scheduler.sigmas[scheduler.step_index + 1], scheduler.sigmas[scheduler.step_index] + alpha_t, sigma_t = scheduler._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = scheduler._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + + mu_xt = (sigma_t / sigma_s * torch.exp(-h)) * sample + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + mu_xt = scheduler.dpm_solver_first_order_update( + model_output=model_output, sample=sample, noise=torch.zeros_like(sample) + ) + + sigma = sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) + if sigma > 0.0: + noise = (prev_latents - mu_xt) / sigma + else: + noise = torch.tensor([0.0]).to(sample.device) + + prev_sample = mu_xt + sigma * noise + return noise, prev_sample + + def second_order_update(model_output_list, sample): # timestep_list, prev_timestep, sample): + sigma_t, sigma_s0, sigma_s1 = ( + scheduler.sigmas[scheduler.step_index + 1], + scheduler.sigmas[scheduler.step_index], + scheduler.sigmas[scheduler.step_index - 1], + ) + + alpha_t, sigma_t = scheduler._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = scheduler._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = scheduler._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + + mu_xt = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + ) + + sigma = sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) + if sigma > 0.0: + noise = (prev_latents - mu_xt) / sigma + else: + noise = torch.tensor([0.0]).to(sample.device) + + prev_sample = mu_xt + sigma * noise + + return noise, prev_sample + + if scheduler.step_index is None: + scheduler._init_step_index(timestep) + + model_output = scheduler.convert_model_output(model_output=noise_pred, sample=latents) + for i in range(scheduler.config.solver_order - 1): + scheduler.model_outputs[i] = scheduler.model_outputs[i + 1] + scheduler.model_outputs[-1] = model_output + + if scheduler.lower_order_nums < 1: + noise, prev_sample = first_order_update(model_output, latents) + else: + noise, prev_sample = second_order_update(scheduler.model_outputs, latents) + + if scheduler.lower_order_nums < scheduler.config.solver_order: + scheduler.lower_order_nums += 1 + + # upon completion increase step index by one + scheduler._step_index += 1 + + return noise, prev_sample + + +def compute_noise(scheduler, *args): + if isinstance(scheduler, DDIMScheduler): + return compute_noise_ddim(scheduler, *args) + elif ( + isinstance(scheduler, DPMSolverMultistepScheduler) + and scheduler.config.algorithm_type == "sde-dpmsolver++" + and scheduler.config.solver_order == 2 + ): + return compute_noise_sde_dpm_pp_2nd(scheduler, *args) + else: + raise NotImplementedError diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..a136770b9f26c73fb1d2885d5980dc9029bf06ff --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py @@ -0,0 +1,1866 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable + +import torch +import torch.nn.functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention_processor import ( + Attention, + AttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import DDIMScheduler, DPMSolverMultistepScheduler +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import LEditsPPDiffusionPipelineOutput, LEditsPPInversionPipelineOutput + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + + >>> from diffusers import LEditsPPPipelineStableDiffusionXL + >>> from diffusers.utils import load_image + + >>> pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe.enable_vae_tiling() + >>> pipe = pipe.to("cuda") + + >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg" + >>> image = load_image(img_url).resize((1024, 1024)) + + >>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.2) + + >>> edited_image = pipe( + ... editing_prompt=["tennis ball", "tomato"], + ... reverse_editing_direction=[True, False], + ... edit_guidance_scale=[5.0, 10.0], + ... edit_threshold=[0.9, 0.85], + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LeditsAttentionStore +class LeditsAttentionStore: + @staticmethod + def get_empty_store(): + return {"down_cross": [], "mid_cross": [], "up_cross": [], "down_self": [], "mid_self": [], "up_self": []} + + def __call__(self, attn, is_cross: bool, place_in_unet: str, editing_prompts, PnP=False): + # attn.shape = batch_size * head_size, seq_len query, seq_len_key + if attn.shape[1] <= self.max_size: + bs = 1 + int(PnP) + editing_prompts + skip = 2 if PnP else 1 # skip PnP & unconditional + attn = torch.stack(attn.split(self.batch_size)).permute(1, 0, 2, 3) + source_batch_size = int(attn.shape[1] // bs) + self.forward(attn[:, skip * source_batch_size :], is_cross, place_in_unet) + + def forward(self, attn, is_cross: bool, place_in_unet: str): + key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" + + self.step_store[key].append(attn) + + def between_steps(self, store_step=True): + if store_step: + if self.average: + if len(self.attention_store) == 0: + self.attention_store = self.step_store + else: + for key in self.attention_store: + for i in range(len(self.attention_store[key])): + self.attention_store[key][i] += self.step_store[key][i] + else: + if len(self.attention_store) == 0: + self.attention_store = [self.step_store] + else: + self.attention_store.append(self.step_store) + + self.cur_step += 1 + self.step_store = self.get_empty_store() + + def get_attention(self, step: int): + if self.average: + attention = { + key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store + } + else: + assert step is not None + attention = self.attention_store[step] + return attention + + def aggregate_attention( + self, attention_maps, prompts, res: int | tuple[int], from_where: list[str], is_cross: bool, select: int + ): + out = [[] for x in range(self.batch_size)] + if isinstance(res, int): + num_pixels = res**2 + resolution = (res, res) + else: + num_pixels = res[0] * res[1] + resolution = res[:2] + + for location in from_where: + for bs_item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: + for batch, item in enumerate(bs_item): + if item.shape[1] == num_pixels: + cross_maps = item.reshape(len(prompts), -1, *resolution, item.shape[-1])[select] + out[batch].append(cross_maps) + + out = torch.stack([torch.cat(x, dim=0) for x in out]) + # average over heads + out = out.sum(1) / out.shape[1] + return out + + def __init__(self, average: bool, batch_size=1, max_resolution=16, max_size: int = None): + self.step_store = self.get_empty_store() + self.attention_store = [] + self.cur_step = 0 + self.average = average + self.batch_size = batch_size + if max_size is None: + self.max_size = max_resolution**2 + elif max_size is not None and max_resolution is None: + self.max_size = max_size + else: + raise ValueError("Only allowed to set one of max_resolution or max_size") + + +# Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LeditsGaussianSmoothing +class LeditsGaussianSmoothing: + def __init__(self, device): + kernel_size = [3, 3] + sigma = [0.5, 0.5] + + # The gaussian kernel is the product of the gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij") + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2)) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(1, *[1] * (kernel.dim() - 1)) + + self.weight = kernel.to(device) + + def __call__(self, input): + """ + Arguments: + Apply gaussian filter to input. + input (torch.Tensor): Input to apply gaussian filter on. + Returns: + filtered (torch.Tensor): Filtered output. + """ + return F.conv2d(input, weight=self.weight.to(input.dtype)) + + +# Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEDITSCrossAttnProcessor +class LEDITSCrossAttnProcessor: + def __init__(self, attention_store, place_in_unet, pnp, editing_prompts): + self.attnstore = attention_store + self.place_in_unet = place_in_unet + self.editing_prompts = editing_prompts + self.pnp = pnp + + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states, + attention_mask=None, + temb=None, + ): + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + self.attnstore( + attention_probs, + is_cross=True, + place_in_unet=self.place_in_unet, + editing_prompts=self.editing_prompts, + PnP=self.pnp, + ) + + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states + + +class LEditsPPPipelineStableDiffusionXL( + DiffusionPipeline, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, +): + """ + Pipeline for textual image editing using LEDits++ with Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`] and builds on the [`StableDiffusionXLPipeline`]. Check the + superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a + particular device, etc.). + + In addition the pipeline inherits the following loading methods: + - *LoRA*: [`LEditsPPPipelineStableDiffusionXL.load_lora_weights`] + - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] + + as well as the following saving methods: + - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer ([`~transformers.CLIPTokenizer`]): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 ([`~transformers.CLIPTokenizer`]): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`DPMSolverMultistepScheduler`] or [`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DPMSolverMultistepScheduler`] or [`DDIMScheduler`]. If any other scheduler is passed it will + automatically be set to [`DPMSolverMultistepScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: DPMSolverMultistepScheduler | DDIMScheduler, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool | None = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + if not isinstance(scheduler, DDIMScheduler) and not isinstance(scheduler, DPMSolverMultistepScheduler): + self.scheduler = DPMSolverMultistepScheduler.from_config( + scheduler.config, algorithm_type="sde-dpmsolver++", solver_order=2 + ) + logger.warning( + "This pipeline only supports DDIMScheduler and DPMSolverMultistepScheduler. " + "The scheduler has been changed to DPMSolverMultistepScheduler." + ) + + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + self.inversion_steps = None + + def encode_prompt( + self, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + enable_edit_guidance: bool = True, + editing_prompt: str | None = None, + editing_prompt_embeds: torch.Tensor | None = None, + editing_pooled_prompt_embeds: torch.Tensor | None = None, + ) -> object: + r""" + Encodes the prompt into text encoder hidden states. + + Args: + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + enable_edit_guidance (`bool`): + Whether to guide towards an editing prompt or not. + editing_prompt (`str` or `list[str]`, *optional*): + Editing prompt(s) to be encoded. If not defined and 'enable_edit_guidance' is True, one has to pass + `editing_prompt_embeds` instead. + editing_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated edit text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided and 'enable_edit_guidance' is True, editing_prompt_embeds will be generated from + `editing_prompt` input argument. + editing_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated edit pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled editing_pooled_prompt_embeds will be generated from `editing_prompt` + input argument. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + batch_size = self.batch_size + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + num_edit_tokens = 0 + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + + if negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but image inversion " + f" has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of the input images." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(negative_prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(negative_pooled_prompt_embeds) + + if enable_edit_guidance and editing_prompt_embeds is None: + editing_prompt_2 = editing_prompt + + editing_prompts = [editing_prompt, editing_prompt_2] + edit_prompt_embeds_list = [] + + for editing_prompt, tokenizer, text_encoder in zip(editing_prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + editing_prompt = self.maybe_convert_prompt(editing_prompt, tokenizer) + + max_length = negative_prompt_embeds.shape[1] + edit_concepts_input = tokenizer( + # [x for item in editing_prompt for x in repeat(item, batch_size)], + editing_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + return_length=True, + ) + num_edit_tokens = edit_concepts_input.length - 2 + + edit_concepts_embeds = text_encoder( + edit_concepts_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + editing_pooled_prompt_embeds = edit_concepts_embeds[0] + if clip_skip is None: + edit_concepts_embeds = edit_concepts_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + edit_concepts_embeds = edit_concepts_embeds.hidden_states[-(clip_skip + 2)] + + edit_prompt_embeds_list.append(edit_concepts_embeds) + + edit_concepts_embeds = torch.concat(edit_prompt_embeds_list, dim=-1) + elif not enable_edit_guidance: + edit_concepts_embeds = None + editing_pooled_prompt_embeds = None + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + bs_embed, seq_len, _ = negative_prompt_embeds.shape + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if enable_edit_guidance: + bs_embed_edit, seq_len, _ = edit_concepts_embeds.shape + edit_concepts_embeds = edit_concepts_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + edit_concepts_embeds = edit_concepts_embeds.repeat(1, num_images_per_prompt, 1) + edit_concepts_embeds = edit_concepts_embeds.view(bs_embed_edit * num_images_per_prompt, seq_len, -1) + + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if enable_edit_guidance: + editing_pooled_prompt_embeds = editing_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed_edit * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return ( + negative_prompt_embeds, + edit_concepts_embeds, + negative_pooled_prompt_embeds, + editing_pooled_prompt_embeds, + num_edit_tokens, + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, eta, generator=None): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ): + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, device, latents): + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.prepare_unet + def prepare_unet(self, attention_store, PnP: bool = False): + attn_procs = {} + for name in self.unet.attn_processors.keys(): + if name.startswith("mid_block"): + place_in_unet = "mid" + elif name.startswith("up_blocks"): + place_in_unet = "up" + elif name.startswith("down_blocks"): + place_in_unet = "down" + else: + continue + + if "attn2" in name and place_in_unet != "mid": + attn_procs[name] = LEDITSCrossAttnProcessor( + attention_store=attention_store, + place_in_unet=place_in_unet, + pnp=PnP, + editing_prompts=self.enabled_editing_prompts, + ) + else: + attn_procs[name] = AttnProcessor() + + self.unet.set_attn_processor(attn_procs) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + denoising_end: float | None = None, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + guidance_rescale: float = 0.0, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] | None = None, + editing_prompt: str | list[str] | None = None, + editing_prompt_embeddings: torch.Tensor | None = None, + editing_pooled_prompt_embeds: torch.Tensor | None = None, + reverse_editing_direction: bool | list[bool] | None = False, + edit_guidance_scale: float | list[float] | None = 5, + edit_warmup_steps: int | list[int] | None = 0, + edit_cooldown_steps: int | list[int] | None = None, + edit_threshold: float | list[float] | None = 0.9, + sem_guidance: list[torch.Tensor] | None = None, + use_cross_attn_mask: bool = False, + use_intersect_mask: bool = False, + user_mask: torch.Tensor | None = None, + attn_store_steps: list[int] | None = [], + store_averaged_over_steps: bool = True, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for editing. The + [`~pipelines.ledits_pp.LEditsPPPipelineStableDiffusionXL.invert`] method has to be called beforehand. Edits + will always be performed for the last inverted image(s). + + Args: + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + editing_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. The image is reconstructed by setting + `editing_prompt = None`. Guidance direction of prompt should be specified via + `reverse_editing_direction`. + editing_prompt_embeddings (`torch.Tensor`, *optional*): + Pre-generated edit text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, editing_prompt_embeddings will be generated from `editing_prompt` input argument. + editing_pooled_prompt_embeddings (`torch.Tensor`, *optional*): + Pre-generated pooled edit text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, editing_prompt_embeddings will be generated from `editing_prompt` input + argument. + reverse_editing_direction (`bool` or `list[bool]`, *optional*, defaults to `False`): + Whether the corresponding prompt in `editing_prompt` should be increased or decreased. + edit_guidance_scale (`float` or `list[float]`, *optional*, defaults to 5): + Guidance scale for guiding the image generation. If provided as list values should correspond to + `editing_prompt`. `edit_guidance_scale` is defined as `s_e` of equation 12 of [LEDITS++ + Paper](https://huggingface.co/papers/2301.12247). + edit_warmup_steps (`float` or `list[float]`, *optional*, defaults to 10): + Number of diffusion steps (for each prompt) for which guidance is not applied. + edit_cooldown_steps (`float` or `list[float]`, *optional*, defaults to `None`): + Number of diffusion steps (for each prompt) after which guidance is no longer applied. + edit_threshold (`float` or `list[float]`, *optional*, defaults to 0.9): + Masking threshold of guidance. Threshold should be proportional to the image region that is modified. + 'edit_threshold' is defined as 'λ' of equation 12 of [LEDITS++ + Paper](https://huggingface.co/papers/2301.12247). + sem_guidance (`list[torch.Tensor]`, *optional*): + list of pre-generated guidance vectors to be applied at generation. Length of the list has to + correspond to `num_inference_steps`. + use_cross_attn_mask: + Whether cross-attention masks are used. Cross-attention masks are always used when use_intersect_mask + is set to true. Cross-attention masks are defined as 'M^1' of equation 12 of [LEDITS++ + paper](https://huggingface.co/papers/2311.16711). + use_intersect_mask: + Whether the masking term is calculated as intersection of cross-attention masks and masks derived from + the noise estimate. Cross-attention mask are defined as 'M^1' and masks derived from the noise estimate + are defined as 'M^2' of equation 12 of [LEDITS++ paper](https://huggingface.co/papers/2311.16711). + user_mask: + User-provided mask for even better control over the editing process. This is helpful when LEDITS++'s + implicit masks do not meet user preferences. + attn_store_steps: + Steps for which the attention maps are stored in the AttentionStore. Just for visualization purposes. + store_averaged_over_steps: + Whether the attention maps for the 'attn_store_steps' are stored averaged over the diffusion steps. If + False, attention maps for each step are stores separately. Just for visualization purposes. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images. + """ + if self.inversion_steps is None: + raise ValueError( + "You need to invert an input image first before calling the pipeline. The `invert` method has to be called beforehand. Edits will always be performed for the last inverted image(s)." + ) + + eta = self.eta + num_images_per_prompt = 1 + latents = self.init_latents + + zs = self.zs + self.scheduler.set_timesteps(len(self.scheduler.timesteps)) + + if use_intersect_mask: + use_cross_attn_mask = True + + if use_cross_attn_mask: + self.smoothing = LeditsGaussianSmoothing(self.device) + + if user_mask is not None: + user_mask = user_mask.to(self.device) + + # TODO: Check inputs + # 1. Check inputs. Raise error if not correct + # self.check_inputs( + # callback_steps, + # negative_prompt, + # negative_prompt_2, + # prompt_embeds, + # negative_prompt_embeds, + # pooled_prompt_embeds, + # negative_pooled_prompt_embeds, + # ) + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + + # 2. Define call parameters + batch_size = self.batch_size + + device = self._execution_device + + if editing_prompt: + enable_edit_guidance = True + if isinstance(editing_prompt, str): + editing_prompt = [editing_prompt] + self.enabled_editing_prompts = len(editing_prompt) + elif editing_prompt_embeddings is not None: + enable_edit_guidance = True + self.enabled_editing_prompts = editing_prompt_embeddings.shape[0] + else: + self.enabled_editing_prompts = 0 + enable_edit_guidance = False + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + edit_prompt_embeds, + negative_pooled_prompt_embeds, + pooled_edit_embeds, + num_edit_tokens, + ) = self.encode_prompt( + device=device, + num_images_per_prompt=num_images_per_prompt, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_embeds=negative_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + enable_edit_guidance=enable_edit_guidance, + editing_prompt=editing_prompt, + editing_prompt_embeds=editing_prompt_embeddings, + editing_pooled_prompt_embeds=editing_pooled_prompt_embeds, + ) + + # 4. Prepare timesteps + # self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps = self.inversion_steps + t_to_idx = {int(v): k for k, v in enumerate(timesteps)} + + if use_cross_attn_mask: + self.attention_store = LeditsAttentionStore( + average=store_averaged_over_steps, + batch_size=batch_size, + max_size=(latents.shape[-2] / 4.0) * (latents.shape[-1] / 4.0), + max_resolution=None, + ) + self.prepare_unet(self.attention_store) + resolution = latents.shape[-2:] + att_res = (int(resolution[0] / 4), int(resolution[1] / 4)) + + # 5. Prepare latent variables + latents = self.prepare_latents(device=device, latents=latents) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(eta) + + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(negative_pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + # 7. Prepare added time ids & embeddings + add_text_embeds = negative_pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + self.size, + crops_coords_top_left, + self.size, + dtype=negative_pooled_prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if enable_edit_guidance: + prompt_embeds = torch.cat([prompt_embeds, edit_prompt_embeds], dim=0) + add_text_embeds = torch.cat([add_text_embeds, pooled_edit_embeds], dim=0) + edit_concepts_time_ids = add_time_ids.repeat(edit_prompt_embeds.shape[0], 1) + add_time_ids = torch.cat([add_time_ids, edit_concepts_time_ids], dim=0) + self.text_cross_attention_maps = [editing_prompt] if isinstance(editing_prompt, str) else editing_prompt + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + if ip_adapter_image is not None: + # TODO: fix image encoding + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + image_embeds = image_embeds.to(device) + + # 8. Denoising loop + self.sem_guidance = None + self.activation_mask = None + + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9. Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=self._num_timesteps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (1 + self.enabled_editing_prompts)) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None: + added_cond_kwargs["image_embeds"] = image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + noise_pred_out = noise_pred.chunk(1 + self.enabled_editing_prompts) # [b,4, 64, 64] + noise_pred_uncond = noise_pred_out[0] + noise_pred_edit_concepts = noise_pred_out[1:] + + noise_guidance_edit = torch.zeros( + noise_pred_uncond.shape, + device=self.device, + dtype=noise_pred_uncond.dtype, + ) + + if sem_guidance is not None and len(sem_guidance) > i: + noise_guidance_edit += sem_guidance[i].to(self.device) + + elif enable_edit_guidance: + if self.activation_mask is None: + self.activation_mask = torch.zeros( + (len(timesteps), self.enabled_editing_prompts, *noise_pred_edit_concepts[0].shape) + ) + if self.sem_guidance is None: + self.sem_guidance = torch.zeros((len(timesteps), *noise_pred_uncond.shape)) + + # noise_guidance_edit = torch.zeros_like(noise_guidance) + for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts): + if isinstance(edit_warmup_steps, list): + edit_warmup_steps_c = edit_warmup_steps[c] + else: + edit_warmup_steps_c = edit_warmup_steps + if i < edit_warmup_steps_c: + continue + + if isinstance(edit_guidance_scale, list): + edit_guidance_scale_c = edit_guidance_scale[c] + else: + edit_guidance_scale_c = edit_guidance_scale + + if isinstance(edit_threshold, list): + edit_threshold_c = edit_threshold[c] + else: + edit_threshold_c = edit_threshold + if isinstance(reverse_editing_direction, list): + reverse_editing_direction_c = reverse_editing_direction[c] + else: + reverse_editing_direction_c = reverse_editing_direction + + if isinstance(edit_cooldown_steps, list): + edit_cooldown_steps_c = edit_cooldown_steps[c] + elif edit_cooldown_steps is None: + edit_cooldown_steps_c = i + 1 + else: + edit_cooldown_steps_c = edit_cooldown_steps + + if i >= edit_cooldown_steps_c: + continue + + noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond + + if reverse_editing_direction_c: + noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1 + + noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c + + if user_mask is not None: + noise_guidance_edit_tmp = noise_guidance_edit_tmp * user_mask + + if use_cross_attn_mask: + out = self.attention_store.aggregate_attention( + attention_maps=self.attention_store.step_store, + prompts=self.text_cross_attention_maps, + res=att_res, + from_where=["up", "down"], + is_cross=True, + select=self.text_cross_attention_maps.index(editing_prompt[c]), + ) + attn_map = out[:, :, :, 1 : 1 + num_edit_tokens[c]] # 0 -> startoftext + + # average over all tokens + if attn_map.shape[3] != num_edit_tokens[c]: + raise ValueError( + f"Incorrect shape of attention_map. Expected size {num_edit_tokens[c]}, but found {attn_map.shape[3]}!" + ) + attn_map = torch.sum(attn_map, dim=3) + + # gaussian_smoothing + attn_map = F.pad(attn_map.unsqueeze(1), (1, 1, 1, 1), mode="reflect") + attn_map = self.smoothing(attn_map).squeeze(1) + + # torch.quantile function expects float32 + if attn_map.dtype == torch.float32: + tmp = torch.quantile(attn_map.flatten(start_dim=1), edit_threshold_c, dim=1) + else: + tmp = torch.quantile( + attn_map.flatten(start_dim=1).to(torch.float32), edit_threshold_c, dim=1 + ).to(attn_map.dtype) + attn_mask = torch.where( + attn_map >= tmp.unsqueeze(1).unsqueeze(1).repeat(1, *att_res), 1.0, 0.0 + ) + + # resolution must match latent space dimension + attn_mask = F.interpolate( + attn_mask.unsqueeze(1), + noise_guidance_edit_tmp.shape[-2:], # 64,64 + ).repeat(1, 4, 1, 1) + self.activation_mask[i, c] = attn_mask.detach().cpu() + if not use_intersect_mask: + noise_guidance_edit_tmp = noise_guidance_edit_tmp * attn_mask + + if use_intersect_mask: + noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp) + noise_guidance_edit_tmp_quantile = torch.sum( + noise_guidance_edit_tmp_quantile, dim=1, keepdim=True + ) + noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat( + 1, self.unet.config.in_channels, 1, 1 + ) + + # torch.quantile function expects float32 + if noise_guidance_edit_tmp_quantile.dtype == torch.float32: + tmp = torch.quantile( + noise_guidance_edit_tmp_quantile.flatten(start_dim=2), + edit_threshold_c, + dim=2, + keepdim=False, + ) + else: + tmp = torch.quantile( + noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32), + edit_threshold_c, + dim=2, + keepdim=False, + ).to(noise_guidance_edit_tmp_quantile.dtype) + + intersect_mask = ( + torch.where( + noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None], + torch.ones_like(noise_guidance_edit_tmp), + torch.zeros_like(noise_guidance_edit_tmp), + ) + * attn_mask + ) + + self.activation_mask[i, c] = intersect_mask.detach().cpu() + + noise_guidance_edit_tmp = noise_guidance_edit_tmp * intersect_mask + + elif not use_cross_attn_mask: + # calculate quantile + noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp) + noise_guidance_edit_tmp_quantile = torch.sum( + noise_guidance_edit_tmp_quantile, dim=1, keepdim=True + ) + noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1, 4, 1, 1) + + # torch.quantile function expects float32 + if noise_guidance_edit_tmp_quantile.dtype == torch.float32: + tmp = torch.quantile( + noise_guidance_edit_tmp_quantile.flatten(start_dim=2), + edit_threshold_c, + dim=2, + keepdim=False, + ) + else: + tmp = torch.quantile( + noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32), + edit_threshold_c, + dim=2, + keepdim=False, + ).to(noise_guidance_edit_tmp_quantile.dtype) + + self.activation_mask[i, c] = ( + torch.where( + noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None], + torch.ones_like(noise_guidance_edit_tmp), + torch.zeros_like(noise_guidance_edit_tmp), + ) + .detach() + .cpu() + ) + + noise_guidance_edit_tmp = torch.where( + noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None], + noise_guidance_edit_tmp, + torch.zeros_like(noise_guidance_edit_tmp), + ) + + noise_guidance_edit += noise_guidance_edit_tmp + + self.sem_guidance[i] = noise_guidance_edit.detach().cpu() + + noise_pred = noise_pred_uncond + noise_guidance_edit + + # compute the previous noisy sample x_t -> x_t-1 + if enable_edit_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_edit_concepts.mean(dim=0, keepdim=False), + guidance_rescale=self.guidance_rescale, + ) + + idx = t_to_idx[int(t)] + latents = self.scheduler.step( + noise_pred, t, latents, variance_noise=zs[idx], **extra_step_kwargs, return_dict=False + )[0] + + # step callback + if use_cross_attn_mask: + store_step = i in attn_store_steps + self.attention_store.between_steps(store_step) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + # negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > 0 and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return LEditsPPDiffusionPipelineOutput(images=image, nsfw_content_detected=None) + + @torch.no_grad() + # Modified from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.encode_image + def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="default", crops_coords=None): + image = self.image_processor.preprocess( + image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + height, width = image.shape[-2:] + if height % 32 != 0 or width % 32 != 0: + raise ValueError( + "Image height and width must be a factor of 32. " + "Consider down-sampling the input using the `height` and `width` parameters" + ) + resized = self.image_processor.postprocess(image=image, output_type="pil") + + if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5: + logger.warning( + "Your input images far exceed the default resolution of the underlying diffusion model. " + "The output images may contain severe artifacts! " + "Consider down-sampling the input using the `height` and `width` parameters" + ) + image = image.to(self.device, dtype=dtype) + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + image = image.float() + self.upcast_vae() + + x0 = self.vae.encode(image).latent_dist.mode() + x0 = x0.to(dtype) + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + + x0 = self.vae.config.scaling_factor * x0 + return x0, resized + + @torch.no_grad() + def invert( + self, + image: PipelineImageInput, + source_prompt: str = "", + source_guidance_scale=3.5, + negative_prompt: str = None, + negative_prompt_2: str = None, + num_inversion_steps: int = 50, + skip: float = 0.15, + generator: torch.Generator | None = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + num_zero_noise_steps: int = 3, + cross_attention_kwargs: dict[str, Any] | None = None, + height: int | None = None, + width: int | None = None, + resize_mode: str | None = "default", + crops_coords: tuple[int, int, int, int] | None = None, + ): + r""" + The function to the pipeline for image inversion as described by the [LEDITS++ + Paper](https://huggingface.co/papers/2301.12247). If the scheduler is set to [`~schedulers.DDIMScheduler`] the + inversion proposed by [edit-friendly DPDM](https://huggingface.co/papers/2304.06140) will be performed instead. + + Args: + image (`PipelineImageInput`): + Input for the image(s) that are to be edited. Multiple input images have to default to the same aspect + ratio. + source_prompt (`str`, defaults to `""`): + Prompt describing the input image that will be used for guidance during inversion. Guidance is disabled + if the `source_prompt` is `""`. + source_guidance_scale (`float`, defaults to `3.5`): + Strength of guidance during inversion. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_inversion_steps (`int`, defaults to `50`): + Number of total performed inversion steps after discarding the initial `skip` steps. + skip (`float`, defaults to `0.15`): + Portion of initial steps that will be ignored for inversion and subsequent generation. Lower values + will lead to stronger changes to the input image. `skip` has to be between `0` and `1`. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make inversion + deterministic. + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + num_zero_noise_steps (`int`, defaults to `3`): + Number of final diffusion steps that will not renoise the current image. If no steps are set to zero + SD-XL in combination with [`DPMSolverMultistepScheduler`] will produce noise artifacts. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Returns: + [`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s) + and respective VAE reconstruction(s). + """ + if height is not None and height % 32 != 0 or width is not None and width % 32 != 0: + raise ValueError("height and width must be a factor of 32.") + + # Reset attn processor, we do not want to store attn maps during inversion + self.unet.set_attn_processor(AttnProcessor()) + + self.eta = 1.0 + + self.scheduler.config.timestep_spacing = "leading" + self.scheduler.set_timesteps(int(num_inversion_steps * (1 + skip))) + self.inversion_steps = self.scheduler.timesteps[-num_inversion_steps:] + timesteps = self.inversion_steps + + num_images_per_prompt = 1 + + device = self._execution_device + + # 0. Ensure that only uncond embedding is used if prompt = "" + if source_prompt == "": + # noise pred should only be noise_pred_uncond + source_guidance_scale = 0.0 + do_classifier_free_guidance = False + else: + do_classifier_free_guidance = source_guidance_scale > 1.0 + + # 1. prepare image + x0, resized = self.encode_image( + image, + dtype=self.text_encoder_2.dtype, + height=height, + width=width, + resize_mode=resize_mode, + crops_coords=crops_coords, + ) + width = x0.shape[2] * self.vae_scale_factor + height = x0.shape[3] * self.vae_scale_factor + self.size = (height, width) + + self.batch_size = x0.shape[0] + + # 2. get embeddings + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + + if isinstance(source_prompt, str): + source_prompt = [source_prompt] * self.batch_size + + ( + negative_prompt_embeds, + prompt_embeds, + negative_pooled_prompt_embeds, + edit_pooled_prompt_embeds, + _, + ) = self.encode_prompt( + device=device, + num_images_per_prompt=num_images_per_prompt, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + editing_prompt=source_prompt, + lora_scale=text_encoder_lora_scale, + enable_edit_guidance=do_classifier_free_guidance, + ) + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(negative_pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + # 3. Prepare added time ids & embeddings + add_text_embeds = negative_pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + self.size, + crops_coords_top_left, + self.size, + dtype=negative_prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if do_classifier_free_guidance: + negative_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([add_text_embeds, edit_pooled_prompt_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + negative_prompt_embeds = negative_prompt_embeds.to(device) + + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(self.batch_size * num_images_per_prompt, 1) + + # autoencoder reconstruction + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + x0_tmp = x0.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + image_rec = self.vae.decode( + x0_tmp / self.vae.config.scaling_factor, return_dict=False, generator=generator + )[0] + elif self.vae.config.force_upcast: + x0_tmp = x0.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + image_rec = self.vae.decode( + x0_tmp / self.vae.config.scaling_factor, return_dict=False, generator=generator + )[0] + else: + image_rec = self.vae.decode(x0 / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0] + + image_rec = self.image_processor.postprocess(image_rec, output_type="pil") + + # 5. find zs and xts + variance_noise_shape = (num_inversion_steps, *x0.shape) + + # intermediate latents + t_to_idx = {int(v): k for k, v in enumerate(timesteps)} + xts = torch.zeros(size=variance_noise_shape, device=self.device, dtype=negative_prompt_embeds.dtype) + + for t in reversed(timesteps): + idx = num_inversion_steps - t_to_idx[int(t)] - 1 + noise = randn_tensor(shape=x0.shape, generator=generator, device=self.device, dtype=x0.dtype) + xts[idx] = self.scheduler.add_noise(x0, noise, t.unsqueeze(0)) + xts = torch.cat([x0.unsqueeze(0), xts], dim=0) + + # noise maps + zs = torch.zeros(size=variance_noise_shape, device=self.device, dtype=negative_prompt_embeds.dtype) + + self.scheduler.set_timesteps(len(self.scheduler.timesteps)) + + for t in self.progress_bar(timesteps): + idx = num_inversion_steps - t_to_idx[int(t)] - 1 + # 1. predict noise residual + xt = xts[idx + 1] + + latent_model_input = torch.cat([xt] * 2) if do_classifier_free_guidance else xt + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=negative_prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # 2. perform guidance + if do_classifier_free_guidance: + noise_pred_out = noise_pred.chunk(2) + noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] + noise_pred = noise_pred_uncond + source_guidance_scale * (noise_pred_text - noise_pred_uncond) + + xtm1 = xts[idx] + z, xtm1_corrected = compute_noise(self.scheduler, xtm1, xt, t, noise_pred, self.eta) + zs[idx] = z + + # correction to avoid error accumulation + xts[idx] = xtm1_corrected + + self.init_latents = xts[-1] + zs = zs.flip(0) + + if num_zero_noise_steps > 0: + zs[-num_zero_noise_steps:] = torch.zeros_like(zs[-num_zero_noise_steps:]) + self.zs = zs + return LEditsPPInversionPipelineOutput(images=resized, vae_reconstruction_images=image_rec) + + +# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.compute_noise_ddim +def compute_noise_ddim(scheduler, prev_latents, latents, timestep, noise_pred, eta): + # 1. get previous step value (=t-1) + prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod + ) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://huggingface.co/papers/2010.02502 + pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) + + # 4. Clip "predicted x_0" + if scheduler.config.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = scheduler._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://huggingface.co/papers/2010.02502 + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred + + # modified so that updated xtm1 is returned as well (to avoid error accumulation) + mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + if variance > 0.0: + noise = (prev_latents - mu_xt) / (variance ** (0.5) * eta) + else: + noise = torch.tensor([0.0]).to(latents.device) + + return noise, mu_xt + (eta * variance**0.5) * noise + + +# Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.compute_noise_sde_dpm_pp_2nd +def compute_noise_sde_dpm_pp_2nd(scheduler, prev_latents, latents, timestep, noise_pred, eta): + def first_order_update(model_output, sample): # timestep, prev_timestep, sample): + sigma_t, sigma_s = scheduler.sigmas[scheduler.step_index + 1], scheduler.sigmas[scheduler.step_index] + alpha_t, sigma_t = scheduler._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = scheduler._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + + mu_xt = (sigma_t / sigma_s * torch.exp(-h)) * sample + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + mu_xt = scheduler.dpm_solver_first_order_update( + model_output=model_output, sample=sample, noise=torch.zeros_like(sample) + ) + + sigma = sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) + if sigma > 0.0: + noise = (prev_latents - mu_xt) / sigma + else: + noise = torch.tensor([0.0]).to(sample.device) + + prev_sample = mu_xt + sigma * noise + return noise, prev_sample + + def second_order_update(model_output_list, sample): # timestep_list, prev_timestep, sample): + sigma_t, sigma_s0, sigma_s1 = ( + scheduler.sigmas[scheduler.step_index + 1], + scheduler.sigmas[scheduler.step_index], + scheduler.sigmas[scheduler.step_index - 1], + ) + + alpha_t, sigma_t = scheduler._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = scheduler._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = scheduler._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + + mu_xt = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + ) + + sigma = sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) + if sigma > 0.0: + noise = (prev_latents - mu_xt) / sigma + else: + noise = torch.tensor([0.0]).to(sample.device) + + prev_sample = mu_xt + sigma * noise + + return noise, prev_sample + + if scheduler.step_index is None: + scheduler._init_step_index(timestep) + + model_output = scheduler.convert_model_output(model_output=noise_pred, sample=latents) + for i in range(scheduler.config.solver_order - 1): + scheduler.model_outputs[i] = scheduler.model_outputs[i + 1] + scheduler.model_outputs[-1] = model_output + + if scheduler.lower_order_nums < 1: + noise, prev_sample = first_order_update(model_output, latents) + else: + noise, prev_sample = second_order_update(scheduler.model_outputs, latents) + + if scheduler.lower_order_nums < scheduler.config.solver_order: + scheduler.lower_order_nums += 1 + + # upon completion increase step index by one + scheduler._step_index += 1 + + return noise, prev_sample + + +# Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.compute_noise +def compute_noise(scheduler, *args): + if isinstance(scheduler, DDIMScheduler): + return compute_noise_ddim(scheduler, *args) + elif ( + isinstance(scheduler, DPMSolverMultistepScheduler) + and scheduler.config.algorithm_type == "sde-dpmsolver++" + and scheduler.config.solver_order == 2 + ): + return compute_noise_sde_dpm_pp_2nd(scheduler, *args) + else: + raise NotImplementedError diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ledits_pp/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ledits_pp/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..4a69e7442bdc9d7e46a060e8033654d714dc4f8e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ledits_pp/pipeline_output.py @@ -0,0 +1,42 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class LEditsPPDiffusionPipelineOutput(BaseOutput): + """ + Output class for LEdits++ Diffusion pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + nsfw_content_detected (`list[bool]`) + list indicating whether the corresponding generated image contains “not-safe-for-work” (nsfw) content or + `None` if safety checking could not be performed. + """ + + images: list[PIL.Image.Image] | np.ndarray + nsfw_content_detected: list[bool] | None + + +@dataclass +class LEditsPPInversionPipelineOutput(BaseOutput): + """ + Output class for LEdits++ Diffusion pipelines. + + Args: + input_images (`list[PIL.Image.Image]` or `np.ndarray`) + list of the cropped and resized input images as PIL images of length `batch_size` or NumPy array of shape ` + (batch_size, height, width, num_channels)`. + vae_reconstruction_images (`list[PIL.Image.Image]` or `np.ndarray`) + list of VAE reconstruction of all input images as PIL images of length `batch_size` or NumPy array of shape + ` (batch_size, height, width, num_channels)`. + """ + + images: list[PIL.Image.Image] | np.ndarray + vae_reconstruction_images: list[PIL.Image.Image] | np.ndarray diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/longcat_image/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/longcat_image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e4bb0e5819c8bcecb0024e12cee72773edd21111 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/longcat_image/__init__.py @@ -0,0 +1,51 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa: F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_longcat_image"] = ["LongCatImagePipeline"] + _import_structure["pipeline_longcat_image_edit"] = ["LongCatImageEditPipeline"] + _import_structure["pipeline_output"] = ["LongCatImagePipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_longcat_image import LongCatImagePipeline + from .pipeline_longcat_image_edit import LongCatImageEditPipeline + from .pipeline_output import LongCatImagePipelineOutput + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/longcat_image/pipeline_longcat_image.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/longcat_image/pipeline_longcat_image.py new file mode 100644 index 0000000000000000000000000000000000000000..19720d7bbab80e85ab1cc08e34ecdc841bd66263 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/longcat_image/pipeline_longcat_image.py @@ -0,0 +1,666 @@ +# Copyright 2025 MeiTuan LongCat-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import re +from typing import Any + +import numpy as np +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import LongCatImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import LongCatImagePipelineOutput +from .system_messages import SYSTEM_PROMPT_EN, SYSTEM_PROMPT_ZH + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LongCatImagePipeline + + >>> pipe = LongCatImagePipeline.from_pretrained("meituan-longcat/LongCat-Image", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "一个年轻的亚裔女性,身穿黄色针织衫,搭配白色项链。她的双手放在膝盖上,表情恬静。背景是一堵粗糙的砖墙,午后的阳光温暖地洒在她身上,营造出一种宁静而温馨的氛围。镜头采用中距离视角,突出她的神态和服饰的细节。光线柔和地打在她的脸上,强调她的五官和饰品的质感,增加画面的层次感与亲和力。整个画面构图简洁,砖墙的纹理与阳光的光影效果相得益彰,突显出人物的优雅与从容。" + >>> image = pipe( + ... prompt, + ... height=768, + ... width=1344, + ... num_inference_steps=50, + ... guidance_scale=4.5, + ... generator=torch.Generator("cpu").manual_seed(43), + ... enable_cfg_renorm=True, + ... ).images[0] + >>> image.save("longcat_image.png") + ``` +""" + + +def get_prompt_language(prompt): + pattern = re.compile(r"[\u4e00-\u9fff]") + if bool(pattern.search(prompt)): + return "zh" + return "en" + + +def split_quotation(prompt, quote_pairs=None): + """ + Implement a regex-based string splitting algorithm that identifies delimiters defined by single or double quote + pairs. Examples:: + >>> prompt_en = "Please write 'Hello' on the blackboard for me." >>> print(split_quotation(prompt_en)) >>> # + output: [('Please write ', False), ("'Hello'", True), (' on the blackboard for me.', False)] + """ + word_internal_quote_pattern = re.compile(r"[a-zA-Z]+'[a-zA-Z]+") + matches_word_internal_quote_pattern = word_internal_quote_pattern.findall(prompt) + mapping_word_internal_quote = [] + + for i, word_src in enumerate(set(matches_word_internal_quote_pattern)): + word_tgt = "longcat_$##$_longcat" * (i + 1) + prompt = prompt.replace(word_src, word_tgt) + mapping_word_internal_quote.append([word_src, word_tgt]) + + if quote_pairs is None: + quote_pairs = [("'", "'"), ('"', '"'), ("‘", "’"), ("“", "”")] + pattern = "|".join([re.escape(q1) + r"[^" + re.escape(q1 + q2) + r"]*?" + re.escape(q2) for q1, q2 in quote_pairs]) + parts = re.split(f"({pattern})", prompt) + + result = [] + for part in parts: + for word_src, word_tgt in mapping_word_internal_quote: + part = part.replace(word_tgt, word_src) + if re.match(pattern, part): + if len(part): + result.append((part, True)) + else: + if len(part): + result.append((part, False)) + return result + + +def prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=None, height=None, width=None): + if type == "text": + assert num_token + if height or width: + print('Warning: The parameters of height and width will be ignored in "text" type.') + pos_ids = torch.zeros(num_token, 3) + pos_ids[..., 0] = modality_id + pos_ids[..., 1] = torch.arange(num_token) + start[0] + pos_ids[..., 2] = torch.arange(num_token) + start[1] + elif type == "image": + assert height and width + if num_token: + print('Warning: The parameter of num_token will be ignored in "image" type.') + pos_ids = torch.zeros(height, width, 3) + pos_ids[..., 0] = modality_id + pos_ids[..., 1] = pos_ids[..., 1] + torch.arange(height)[:, None] + start[0] + pos_ids[..., 2] = pos_ids[..., 2] + torch.arange(width)[None, :] + start[1] + pos_ids = pos_ids.reshape(height * width, 3) + else: + raise KeyError(f'Unknow type {type}, only support "text" or "image".') + return pos_ids + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class LongCatImagePipeline(DiffusionPipeline, FromSingleFileMixin): + r""" + The pipeline for text-to-image generation. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + text_processor: Qwen2VLProcessor, + transformer: LongCatImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_processor=text_processor, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + self.prompt_template_encode_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n" + self.prompt_template_encode_suffix = "<|im_end|>\n<|im_start|>assistant\n" + self.default_sample_size = 128 + self.tokenizer_max_length = 512 + + def rewire_prompt(self, prompt, device): + prompt = [prompt] if isinstance(prompt, str) else prompt + all_text = [] + for each_prompt in prompt: + language = get_prompt_language(each_prompt) + if language == "zh": + question = SYSTEM_PROMPT_ZH + f"\n用户输入为:{each_prompt}\n改写后的prompt为:" + else: + question = SYSTEM_PROMPT_EN + f"\nUser Input: {each_prompt}\nRewritten prompt:" + message = [ + { + "role": "user", + "content": [ + {"type": "text", "text": question}, + ], + } + ] + # Preparation for inference + text = self.text_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True) + all_text.append(text) + + inputs = self.text_processor(text=all_text, padding=True, return_tensors="pt").to(self.text_encoder.device) + + generated_ids = self.text_encoder.generate(**inputs, max_new_tokens=self.tokenizer_max_length) + generated_ids.to(device) + generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] + output_text = self.text_processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + rewrite_prompt = output_text + return rewrite_prompt + + def _encode_prompt(self, prompt: list[str]): + batch_all_tokens = [] + + for each_prompt in prompt: + all_tokens = [] + for clean_prompt_sub, matched in split_quotation(each_prompt): + if matched: + for sub_word in clean_prompt_sub: + tokens = self.tokenizer(sub_word, add_special_tokens=False)["input_ids"] + all_tokens.extend(tokens) + else: + tokens = self.tokenizer(clean_prompt_sub, add_special_tokens=False)["input_ids"] + all_tokens.extend(tokens) + + if len(all_tokens) > self.tokenizer_max_length: + logger.warning( + "Your input was truncated because `max_sequence_length` is set to " + f" {self.tokenizer_max_length} input token nums : {len(all_tokens)}" + ) + all_tokens = all_tokens[: self.tokenizer_max_length] + batch_all_tokens.append(all_tokens) + + text_tokens_and_mask = self.tokenizer.pad( + {"input_ids": batch_all_tokens}, + max_length=self.tokenizer_max_length, + padding="max_length", + return_attention_mask=True, + return_tensors="pt", + ) + + prefix_tokens = self.tokenizer(self.prompt_template_encode_prefix, add_special_tokens=False)["input_ids"] + suffix_tokens = self.tokenizer(self.prompt_template_encode_suffix, add_special_tokens=False)["input_ids"] + prefix_len = len(prefix_tokens) + suffix_len = len(suffix_tokens) + + prefix_tokens_mask = torch.tensor([1] * len(prefix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype) + suffix_tokens_mask = torch.tensor([1] * len(suffix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype) + + prefix_tokens = torch.tensor(prefix_tokens, dtype=text_tokens_and_mask.input_ids.dtype) + suffix_tokens = torch.tensor(suffix_tokens, dtype=text_tokens_and_mask.input_ids.dtype) + + batch_size = text_tokens_and_mask.input_ids.size(0) + + prefix_tokens_batch = prefix_tokens.unsqueeze(0).expand(batch_size, -1) + suffix_tokens_batch = suffix_tokens.unsqueeze(0).expand(batch_size, -1) + prefix_mask_batch = prefix_tokens_mask.unsqueeze(0).expand(batch_size, -1) + suffix_mask_batch = suffix_tokens_mask.unsqueeze(0).expand(batch_size, -1) + + input_ids = torch.cat((prefix_tokens_batch, text_tokens_and_mask.input_ids, suffix_tokens_batch), dim=-1) + attention_mask = torch.cat((prefix_mask_batch, text_tokens_and_mask.attention_mask, suffix_mask_batch), dim=-1) + + input_ids = input_ids.to(self.device) + attention_mask = attention_mask.to(self.device) + + text_output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True) + # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] + # clone to have a contiguous tensor + prompt_embeds = text_output.hidden_states[-1].detach() + prompt_embeds = prompt_embeds[:, prefix_len:-suffix_len, :] + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int | None = 1, + prompt_embeds: torch.Tensor | None = None, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is None: + prompt_embeds = self._encode_prompt(prompt) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=prompt_embeds.shape[1]).to( + self.device + ) + return prompt_embeds.to(self.device), text_ids + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = prepare_pos_ids( + modality_id=1, + type="image", + start=(self.tokenizer_max_length, self.tokenizer_max_length), + height=height // 2, + width=width // 2, + ).to(device) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device) + latents = latents.to(dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + def check_inputs( + self, prompt, height, width, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + @replace_example_docstring(EXAMPLE_DOC_STRING) + @torch.no_grad() + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 4.5, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + enable_cfg_renorm: bool | None = True, + cfg_renorm_min: float | None = 0.0, + enable_prompt_rewrite: bool | None = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + enable_cfg_renorm: Whether to enable cfg_renorm. Enabling cfg_renorm will improve image quality, + but it may lead to a decrease in the stability of some image outputs.. + cfg_renorm_min: The minimum value of the cfg_renorm_scale range (0-1). + cfg_renorm_min = 1.0, renorm has no effect, while cfg_renorm_min=0.0, the renorm range is larger. + enable_prompt_rewrite: whether to enable prompt rewrite. + Examples: + + Returns: + [`~pipelines.LongCatImagePipelineOutput`] or `tuple`: [`~pipelines.LongCatImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + if enable_prompt_rewrite: + prompt = self.rewire_prompt(prompt, device) + logger.info(f"Rewrite prompt {prompt}!") + + negative_prompt = "" if negative_prompt is None else negative_prompt + (prompt_embeds, text_ids) = self.encode_prompt( + prompt=prompt, prompt_embeds=prompt_embeds, num_images_per_prompt=num_images_per_prompt + ) + if self.do_classifier_free_guidance: + (negative_prompt_embeds, negative_text_ids) = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + ) + + # 4. Prepare latent variables + num_channels_latents = 16 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + guidance = None + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred_text = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_pred_uncond = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if enable_cfg_renorm: + cond_norm = torch.norm(noise_pred_text, dim=-1, keepdim=True) + noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + scale = (cond_norm / (noise_norm + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) + noise_pred = noise_pred * scale + else: + noise_pred = noise_pred_text + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + if latents.dtype != self.vae.dtype: + latents = latents.to(dtype=self.vae.dtype) + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return LongCatImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py new file mode 100644 index 0000000000000000000000000000000000000000..69d5d82f18ecd94937cad4bca1f2b8dfc0d5c23e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py @@ -0,0 +1,727 @@ +# Copyright 2025 MeiTuan LongCat-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import math +import re +from typing import Any + +import numpy as np +import PIL +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import LongCatImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import LongCatImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from PIL import Image + >>> import torch + >>> from diffusers import LongCatImageEditPipeline + + >>> pipe = LongCatImageEditPipeline.from_pretrained( + ... "meituan-longcat/LongCat-Image-Edit", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> prompt = "change the cat to dog." + >>> input_image = Image.open("test.jpg").convert("RGB") + >>> image = pipe( + ... input_image, + ... prompt, + ... num_inference_steps=50, + ... guidance_scale=4.5, + ... generator=torch.Generator("cpu").manual_seed(43), + ... ).images[0] + >>> image.save("longcat_image_edit.png") + ``` +""" + + +# Copied from diffusers.pipelines.longcat_image.pipeline_longcat_image.split_quotation +def split_quotation(prompt, quote_pairs=None): + """ + Implement a regex-based string splitting algorithm that identifies delimiters defined by single or double quote + pairs. Examples:: + >>> prompt_en = "Please write 'Hello' on the blackboard for me." >>> print(split_quotation(prompt_en)) >>> # + output: [('Please write ', False), ("'Hello'", True), (' on the blackboard for me.', False)] + """ + word_internal_quote_pattern = re.compile(r"[a-zA-Z]+'[a-zA-Z]+") + matches_word_internal_quote_pattern = word_internal_quote_pattern.findall(prompt) + mapping_word_internal_quote = [] + + for i, word_src in enumerate(set(matches_word_internal_quote_pattern)): + word_tgt = "longcat_$##$_longcat" * (i + 1) + prompt = prompt.replace(word_src, word_tgt) + mapping_word_internal_quote.append([word_src, word_tgt]) + + if quote_pairs is None: + quote_pairs = [("'", "'"), ('"', '"'), ("‘", "’"), ("“", "”")] + pattern = "|".join([re.escape(q1) + r"[^" + re.escape(q1 + q2) + r"]*?" + re.escape(q2) for q1, q2 in quote_pairs]) + parts = re.split(f"({pattern})", prompt) + + result = [] + for part in parts: + for word_src, word_tgt in mapping_word_internal_quote: + part = part.replace(word_tgt, word_src) + if re.match(pattern, part): + if len(part): + result.append((part, True)) + else: + if len(part): + result.append((part, False)) + return result + + +# Copied from diffusers.pipelines.longcat_image.pipeline_longcat_image.prepare_pos_ids +def prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=None, height=None, width=None): + if type == "text": + assert num_token + if height or width: + print('Warning: The parameters of height and width will be ignored in "text" type.') + pos_ids = torch.zeros(num_token, 3) + pos_ids[..., 0] = modality_id + pos_ids[..., 1] = torch.arange(num_token) + start[0] + pos_ids[..., 2] = torch.arange(num_token) + start[1] + elif type == "image": + assert height and width + if num_token: + print('Warning: The parameter of num_token will be ignored in "image" type.') + pos_ids = torch.zeros(height, width, 3) + pos_ids[..., 0] = modality_id + pos_ids[..., 1] = pos_ids[..., 1] + torch.arange(height)[:, None] + start[0] + pos_ids[..., 2] = pos_ids[..., 2] + torch.arange(width)[None, :] + start[1] + pos_ids = pos_ids.reshape(height * width, 3) + else: + raise KeyError(f'Unknow type {type}, only support "text" or "image".') + return pos_ids + + +# Copied from diffusers.pipelines.longcat_image.pipeline_longcat_image.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def calculate_dimensions(target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = width if width % 16 == 0 else (width // 16 + 1) * 16 + height = height if height % 16 == 0 else (height // 16 + 1) * 16 + + width = int(width) + height = int(height) + + return width, height + + +class LongCatImageEditPipeline(DiffusionPipeline, FromSingleFileMixin): + r""" + The LongCat-Image-Edit pipeline for image editing. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + text_processor: Qwen2VLProcessor, + transformer: LongCatImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_processor=text_processor, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.image_processor_vl = text_processor.image_processor + + self.image_token = "<|image_pad|>" + self.prompt_template_encode_prefix = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>" + self.prompt_template_encode_suffix = "<|im_end|>\n<|im_start|>assistant\n" + self.default_sample_size = 128 + self.tokenizer_max_length = 512 + + def _encode_prompt(self, prompt, image): + raw_vl_input = self.image_processor_vl(images=image, return_tensors="pt") + pixel_values = raw_vl_input["pixel_values"] + image_grid_thw = raw_vl_input["image_grid_thw"] + all_tokens = [] + for clean_prompt_sub, matched in split_quotation(prompt[0]): + if matched: + for sub_word in clean_prompt_sub: + tokens = self.tokenizer(sub_word, add_special_tokens=False)["input_ids"] + all_tokens.extend(tokens) + else: + tokens = self.tokenizer(clean_prompt_sub, add_special_tokens=False)["input_ids"] + all_tokens.extend(tokens) + + if len(all_tokens) > self.tokenizer_max_length: + logger.warning( + "Your input was truncated because `max_sequence_length` is set to " + f" {self.tokenizer_max_length} input token nums : {len(len(all_tokens))}" + ) + all_tokens = all_tokens[: self.tokenizer_max_length] + + text_tokens_and_mask = self.tokenizer.pad( + {"input_ids": [all_tokens]}, + max_length=self.tokenizer_max_length, + padding="max_length", + return_attention_mask=True, + return_tensors="pt", + ) + + text = self.prompt_template_encode_prefix + + merge_length = self.image_processor_vl.merge_size**2 + while self.image_token in text: + num_image_tokens = image_grid_thw.prod() // merge_length + text = text.replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) + text = text.replace("<|placeholder|>", self.image_token) + + prefix_tokens = self.tokenizer(text, add_special_tokens=False)["input_ids"] + suffix_tokens = self.tokenizer(self.prompt_template_encode_suffix, add_special_tokens=False)["input_ids"] + + vision_start_token_id = self.tokenizer.convert_tokens_to_ids("<|vision_start|>") + prefix_len = prefix_tokens.index(vision_start_token_id) + suffix_len = len(suffix_tokens) + + prefix_tokens_mask = torch.tensor([1] * len(prefix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype) + suffix_tokens_mask = torch.tensor([1] * len(suffix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype) + + prefix_tokens = torch.tensor(prefix_tokens, dtype=text_tokens_and_mask.input_ids.dtype) + suffix_tokens = torch.tensor(suffix_tokens, dtype=text_tokens_and_mask.input_ids.dtype) + + input_ids = torch.cat((prefix_tokens, text_tokens_and_mask.input_ids[0], suffix_tokens), dim=-1) + attention_mask = torch.cat( + (prefix_tokens_mask, text_tokens_and_mask.attention_mask[0], suffix_tokens_mask), dim=-1 + ) + + input_ids = input_ids.unsqueeze(0).to(self.device) + attention_mask = attention_mask.unsqueeze(0).to(self.device) + + pixel_values = pixel_values.to(self.device) + image_grid_thw = image_grid_thw.to(self.device) + + text_output = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + output_hidden_states=True, + ) + # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] + # clone to have a contiguous tensor + prompt_embeds = text_output.hidden_states[-1].detach() + prompt_embeds = prompt_embeds[:, prefix_len:-suffix_len, :] + return prompt_embeds + + def encode_prompt( + self, + prompt: list[str] = None, + image: torch.Tensor | None = None, + num_images_per_prompt: int | None = 1, + prompt_embeds: torch.Tensor | None = None, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is None: + prompt_embeds = self._encode_prompt(prompt, image) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=prompt_embeds.shape[1]).to( + self.device + ) + return prompt_embeds, text_ids + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + def prepare_latents( + self, + image, + batch_size, + num_channels_latents, + height, + width, + dtype, + prompt_embeds_length, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + image_latents, image_latents_ids = None, None + + if image is not None: + image = image.to(device=self.device, dtype=dtype) + + if image.shape[1] != self.vae.config.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + + image_latents_ids = prepare_pos_ids( + modality_id=2, + type="image", + start=(prompt_embeds_length, prompt_embeds_length), + height=height // 2, + width=width // 2, + ).to(device, dtype=torch.float64) + + shape = (batch_size, num_channels_latents, height, width) + latents_ids = prepare_pos_ids( + modality_id=1, + type="image", + start=(prompt_embeds_length, prompt_embeds_length), + height=height // 2, + width=width // 2, + ).to(device) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents, image_latents, latents_ids, image_latents_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + def check_inputs( + self, prompt, height, width, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None: + if isinstance(prompt, str): + pass + elif isinstance(prompt, list) and len(prompt) == 1: + pass + else: + raise ValueError( + f"`prompt` must be a `str` or a `list` of length 1, but is {prompt} (type: {type(prompt)})" + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + @replace_example_docstring(EXAMPLE_DOC_STRING) + @torch.no_grad() + def __call__( + self, + image: PIL.Image.Image | None = None, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 4.5, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Examples: + + Returns: + [`~pipelines.LongCatImagePipelineOutput`] or `tuple`: [`~pipelines.LongCatImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + image_size = image[0].size if isinstance(image, list) else image.size + calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] * 1.0 / image_size[1]) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + calculated_height, + calculated_width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + image = self.image_processor.resize(image, calculated_height, calculated_width) + prompt_image = self.image_processor.resize(image, calculated_height // 2, calculated_width // 2) + image = self.image_processor.preprocess(image, calculated_height, calculated_width) + + negative_prompt = "" if negative_prompt is None else negative_prompt + (prompt_embeds, text_ids) = self.encode_prompt( + prompt=prompt, image=prompt_image, prompt_embeds=prompt_embeds, num_images_per_prompt=num_images_per_prompt + ) + if self.do_classifier_free_guidance: + (negative_prompt_embeds, negative_text_ids) = self.encode_prompt( + prompt=negative_prompt, + image=prompt_image, + prompt_embeds=negative_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + ) + + # 4. Prepare latent variables + num_channels_latents = 16 + latents, image_latents, latents_ids, image_latents_ids = self.prepare_latents( + image, + batch_size * num_images_per_prompt, + num_channels_latents, + calculated_height, + calculated_width, + prompt_embeds.dtype, + prompt_embeds.shape[1], + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + guidance = None + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + if image is not None: + latent_image_ids = torch.cat([latents_ids, image_latents_ids], dim=0) + else: + latent_image_ids = latents_ids + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred_text = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + noise_pred_text = noise_pred_text[:, :image_seq_len] + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + noise_pred_uncond = noise_pred_uncond[:, :image_seq_len] + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred = noise_pred_text + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, calculated_height, calculated_width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + if latents.dtype != self.vae.dtype: + latents = latents.to(dtype=self.vae.dtype) + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return LongCatImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/longcat_image/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/longcat_image/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..290178eb1a8859b682e64f44f582f86a5b236ab1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/longcat_image/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from diffusers.utils import BaseOutput + + +@dataclass +class LongCatImagePipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image, np.ndarray] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/longcat_image/system_messages.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/longcat_image/system_messages.py new file mode 100644 index 0000000000000000000000000000000000000000..b8b2318e4e813970ebba7e5da4a07c26f238f4e8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/longcat_image/system_messages.py @@ -0,0 +1,142 @@ +SYSTEM_PROMPT_EN = """ +You are a prompt engineering expert for text-to-image models. Since text-to-image models have limited capabilities in +understanding user prompts, you need to identify the core theme and intent of the user's input and improve the model's +understanding accuracy and generation quality through optimization and rewriting. The rewrite must strictly retain all +information from the user's original prompt without deleting or distorting any details. Specific requirements are as +follows: +1. The rewrite must not affect any information expressed in the user's original prompt; the rewritten prompt should use + coherent natural language, avoid low-information redundant descriptions, and keep the rewritten prompt length as + concise as possible. +2. Ensure consistency between input and output languages: Chinese input yields Chinese output, and English input yields + English output. The rewritten token count should not exceed 512. +3. The rewritten description should further refine subject characteristics and aesthetic techniques appearing in the + original prompt, such as lighting and textures. +4. If the original prompt does not specify an image style, ensure the rewritten prompt uses a **realistic photography + style**. If the user specifies a style, retain the user's style. +5. When the original prompt requires reasoning to clarify user intent, use logical reasoning based on world knowledge + to convert vague abstract descriptions into specific tangible objects (e.g., convert "the tallest animal" to "a + giraffe"). +6. When the original prompt requires text generation, please use double quotes to enclose the text part (e.g., `"50% + OFF"`). +7. When the original prompt requires generating text-heavy scenes like webpages, logos, UIs, or posters, and no + specific text content is specified, you need to infer appropriate text content and enclose it in double quotes. For + example, if the user inputs: "A tourism flyer with a grassland theme," it should be rewritten as: "A tourism flyer + with the image title 'Grassland'." +8. When negative words exist in the original prompt, ensure the rewritten prompt does not contain negative words. For + example, "a lakeside without boats" should be rewritten such that the word "boat" does not appear at all. +9. Except for text content explicitly requested by the user, **adding any extra text content is prohibited**. +Here are examples of rewrites for different types of prompts: # Examples (Few-Shot Learning) + 1. User Input: An animal with nine lives. + Rewrite Output: A cat bathed in soft sunlight, its fur soft and glossy. The background is a comfortable home + environment with light from the window filtering through curtains, creating a warm light and shadow effect. The + shot uses a medium distance perspective to highlight the cat's leisurely and stretched posture. Light cleverly hits + the cat's face, emphasizing its spirited eyes and delicate whiskers, adding depth and affinity to the image. + 2. User Input: Create an anime-style tourism flyer with a grassland theme. + Rewrite Output: In the lower right of the center, a short-haired girl sits sideways on a gray, irregularly shaped + rock. She wears a white short-sleeved dress and brown flat shoes, holding a bunch of small white flowers in her + left hand, smiling with her legs hanging naturally. The girl has dark brown shoulder-length hair with bangs + covering her forehead, brown eyes, and a slightly open mouth. The rock surface has textures of varying depths. To + the girl's left and front is lush grass, with long, yellow-green blades, some glowing golden in the sunlight. The + grass extends into the distance, forming rolling green hills that fade in color as they recede. The sky occupies + the upper half of the picture, pale blue dotted with a few fluffy white clouds. In the upper left corner, there is + a line of text in italic, dark green font reading "Explore Nature's Peace". Colors are dominated by green, blue, + and yellow, fluid lines, and distinct light and shadow contrast, creating a quiet and comfortable atmosphere. + 3. User Input: A Christmas sale poster with a red background, promoting a Buy 1 Get 1 Free milk tea offer. + Rewrite Output: The poster features an overall red tone, embellished with white snowflake patterns on the top and + left side. The upper right features a bunch of holly leaves with red berries and a pine cone. In the upper center, + golden 3D text reads "Christmas Heartwarming Feedback" centered, along with red bold text "Buy 1 Get 1". Below, two + transparent cups filled with bubble tea are placed side by side; the tea is light brown with dark brown pearls + scattered at the bottom and middle. Below the cups, white snow piles up, decorated with pine branches, red berries, + and pine cones. A blurry Christmas tree is faintly visible in the lower right corner. The image has high clarity, + accurate text content, a unified design style, a prominent Christmas theme, and a reasonable layout, providing + strong visual appeal. + 4. User Input: A woman indoors shot in natural light, smiling with arms crossed, showing a relaxed and confident + posture. + Rewrite Output: The image features a young Asian woman with long dark brown hair naturally falling over her + shoulders, with some strands illuminated by light, showing a soft sheen. Her features are delicate, with long + eyebrows, bright and spirited dark brown eyes looking directly at the camera, revealing peace and confidence. She + has a high nose bridge, full lips with nude lipstick, and corners of the mouth slightly raised in a faint smile. + Her skin is fair, with cheeks and collarbones illuminated by warm light, showing a healthy ruddiness. She wears a + black spaghetti strap tank top revealing graceful collarbone lines, and a thin gold necklace with small beads and + metal bars glinting in the light. Her outer layer is a beige knitted cardigan, soft in texture with visible + knitting patterns on the sleeves. Her arms are crossed over her chest, hands covered by the cardigan sleeves, in a + relaxed posture. The background is a pure dark brown without extra decoration, making the figure the absolute + focus. The figure is located in the center of the frame. Light enters from the upper right, creating bright spots + on her left cheek, neck, and collarbone, while the right side is slightly shadowed, creating a three-dimensional + and soft tone. Image details are clear, showcasing skin texture, hair, and clothing materials well. Colors are + dominated by warm tones, with the combination of beige and dark brown creating a warm and comfortable atmosphere. + The overall style is natural, elegant, and artistic. + 5. User Input: Create a series of images showing the growth process of an apple from seed to fruit. The series should + include four stages: 1. Sowing, 2. Seedling growth, 3. Plant maturity, 4. Fruit harvesting. + Rewrite Output: A 4-panel exquisite illustration depicting the growth process of an apple, capturing each stage + precisely and clearly. 1. "Sowing": A close-up shot of a hand gently placing a small apple seed into fertile dark + soil, with visible soil texture and the seed's smooth surface. The background is a soft-focus garden dotted with + green leaves and sunlight filtering through. 2. "Seedling Growth": A young apple sapling breaks through the soil, + stretching tender green leaves toward the sky. The scene is set in a vibrant garden illuminated by warm golden + light, highlighting the seedling's delicate structure. 3. "Plant Maturity": A mature apple tree, lush with branches + and leaves, covered in tender green foliage and developing small apples. The background is a vibrant orchard under + a clear blue sky, with dappled sunlight creating a peaceful atmosphere. 4. "Fruit Harvesting": A hand reaches into + the tree to pick a ripe red apple, its smooth skin glistening in the sun. The scene shows the abundance of the + orchard, with baskets of apples in the background, giving a sense of fulfillment. Each illustration uses a + realistic style, focusing on details and harmonious colors to showcase the natural beauty and development of the + apple's life cycle. + 6. User Input: If 1 represents red, 2 represents green, 3 represents purple, and 4 represents yellow, please generate + a four-color rainbow based on this rule. The color order from top to bottom is 3142. + Rewrite Output: The image consists of four horizontally arranged colored stripes, ordered from top to bottom as + purple, red, yellow, and green. A white number is centered on each stripe. The top purple stripe features the + number "3", the red stripe below it has the number "1", the yellow stripe further down has the number "4", and the + bottom green stripe has the number "2". All numbers use a sans-serif font in pure white, forming a sharp contrast + with the background colors to ensure good readability. The stripes have high color saturation and a slight texture. + The overall layout is simple and clear, with distinct visual effects and no extra decorative elements, emphasizing + the numerical information. The image is high definition, with accurate colors and a consistent style, offering + strong visual appeal. + 7. User Input: A stone tablet carved with "Guan Guan Ju Jiu, On the River Isle", natural light, background is a + Chinese garden. + Rewrite Output: An ancient stone tablet carved with "Guan Guan Ju Jiu, On the River Isle", the surface covered with + traces of time, the writing clear and deep. Natural light falls from above, softly illuminating every detail of the + stone tablet and enhancing its sense of history. The background is an elegant Chinese garden featuring lush bamboo + forests, winding paths, and quiet pools, creating a serene and distant atmosphere. The overall picture uses a + realistic style with rich details and natural light and shadow effects, highlighting the cultural heritage of the + stone tablet and the classical beauty of the garden. +# Output Format Please directly output the rewritten and optimized Prompt content. Do not include any explanatory +language or JSON formatting, and do not add opening or closing quotes yourself.""" + + +SYSTEM_PROMPT_ZH = """ +你是一名文生图模型的prompt +engineering专家。由于文生图模型对用户prompt的理解能力有限,你需要识别用户输入的核心主题和意图,并通过优化改写提升模型的理解准确性和生成质量。改写必须严格保留用户原始prompt的所有信息,不得删减或曲解任何细节。 +具体要求如下: +1. 改写不能影响用户原始prompt里表达的任何信息,改写后的prompt应该使用连贯的自然语言表达,不要出现低信息量的冗余描述,尽可能保持改写后prompt长度精简。 +2. 请确保输入和输出的语言类型一致,中文输入中文输出,英文输入英文输出,改写后的token数量不要超过512个; +3. 改写后的描述应当进一步完善原始prompt中出现的主体特征、美学技巧,如打光、纹理等; +4. 如果原始prompt没有指定图片风格时,确保改写后的prompt使用真实摄影风格,如果用户指定了图片风格,则保留用户风格; +5. 当原始prompt需要推理才能明确用户意图时,根据世界知识进行适当逻辑推理,将模糊抽象描述转化为具体指向事物(例:将"最高的动物"转化为"一头长颈鹿")。 +6. 当原始prompt需要生成文字时,请使用双引号圈定文字部分,例:`"限时5折"`)。 +7. 当原始prompt需要生成网页、logo、ui、海报等文字场景时,且没有指定具体的文字内容时,需要推断出合适的文字内容,并使用双引号圈定,如用户输入:一个旅游宣传单,以草原为主题。应该改写成:一个旅游宣传单,图片标题为“草原”。 +8. 当原始prompt中存在否定词时,需要确保改写后的prompt不存在否定词,如没有船的湖边,改写后的prompt不能出现船这个词汇。 +9. 除非用户指定生成品牌logo,否则不要增加额外的品牌logo. +10. 除了用户明确要求书写的文字内容外,**禁止增加任何额外的文字内容**。 +以下是针对不同类型prompt改写的示例: + +# Examples (Few-Shot Learning) + 1. 用户输入: 九条命的动物。 + 改写输出: + 一只猫,被柔和的阳光笼罩着,毛发柔软而富有光泽。背景是一个舒适的家居环境,窗外的光线透过窗帘,形成温馨的光影效果。镜头采用中距离视角,突出猫悠闲舒展的姿态。光线巧妙地打在猫的脸部,强调它灵动的眼睛和精致的胡须,增加画面的层次感与亲和力。 + 2. 用户输入: 制作一个动画风格的旅游宣传单,以草原为主题。 + 改写输出: + 画面中央偏右下角,一个短发女孩侧身坐在灰色的不规则形状岩石上,她穿着白色短袖连衣裙和棕色平底鞋,左手拿着一束白色小花,面带微笑,双腿自然垂下。女孩的头发为深棕色,齐肩短发,刘海覆盖额头,眼睛呈棕色,嘴巴微张。岩石表面有深浅不一的纹理。女孩的左侧和前方是茂盛的草地,草叶细长,呈黄绿色,部分草叶在阳光下泛着金色的光芒,仿佛被阳光照亮。草地向远处延伸,形成连绵起伏的绿色山丘,山丘的颜色由近及远逐渐变浅。天空占据了画面的上半部分,呈淡蓝色,点缀着几朵白色蓬松的云彩。画面的左上角有一行文字,文字内容是斜体、深绿色的“Explore + Nature's Peace”。色彩以绿色、蓝色和黄色为主,线条流畅,光影明暗对比明显,营造出一种宁静、舒适的氛围。 + 3. 用户输入: 一张以红色为背景的圣诞节促销海报,主要宣传奶茶买一送一的优惠活动。 + 改写输出: 海报整体呈现红色调,上方和左侧点缀着白色雪花图案,右上方有一束冬青叶和红色浆果,以及一个松果。海报中央偏上位置,金色立体字样“圣诞节 + 暖心回馈”居中排列,和红色粗体字“买1送1”。海报下方,两个装满珍珠奶茶的透明杯子并排摆放,杯中奶茶呈浅棕色,底部和中间散布着深棕色珍珠。杯子下方,堆积着白色雪花,雪花上装饰着松枝、红色浆果和松果。右下角隐约可见一棵模糊的圣诞树。图片清晰度高,文字内容准确,整体设计风格统一,圣诞主题突出,排版布局合理,具有较强的视觉吸引力。 + 4. 用户输入: 一位女性在室内以自然光线拍摄,她面带微笑,双臂交叉,展现出轻松自信的姿态。 + 改写输出: + 画面中是一位年轻的亚洲女性,她拥有深棕色的长发,发丝自然地垂落在双肩,部分发丝被光线照亮,呈现出柔和的光泽。她的五官精致,眉毛修长,眼睛明亮有神,瞳孔呈深棕色,眼神直视镜头,流露出平和与自信。鼻梁挺拔,嘴唇丰满,涂有裸色系唇膏,嘴角微微上扬,展现出浅浅的微笑。她的肤色白皙,脸颊和锁骨处被暖色调的光线照亮,呈现出健康的红润感。她穿着一件黑色的细吊带背心,肩带纤细,露出优美的锁骨线条。脖颈上佩戴着一条金色的细项链,项链由小珠子和几个细长的金属条组成,在光线下闪烁着光泽。她的外搭是一件米黄色的针织开衫,材质柔软,袖子部分有明显的针织纹理。她双臂交叉在胸前,双手被开衫的袖子覆盖,姿态放松。背景是纯粹的深棕色,没有多余的装饰,使得人物成为画面的绝对焦点。人物位于画面中央。光线从画面的右上方射入,在人物的左侧脸颊、脖颈和锁骨处形成明亮的光斑,右侧则略显阴影,营造出立体感和柔和的影调。图像细节清晰,人物的皮肤纹理、发丝以及衣物材质都得到了很好的展现。色彩以暖色调为主,米黄色和深棕色的搭配营造出温馨舒适的氛围。整体呈现出一种自然、优雅且富有亲和力的艺术风格。 + 5. 用户输入:创作一系列图片,展现苹果从种子到结果的生长过程。该系列图片应包含以下四个阶段:1. 播种,2. 幼苗生长,3. 植物成熟,4. 果实采摘。 + 改写输出:一个4宫格的精美插图,描绘苹果的生长过程,精确清晰地捕捉每个阶段。1.“播种”:特写镜头,一只手轻轻地将一颗小小的苹果种子放入肥沃的深色土壤中,土壤的纹理和种子光滑的表面清晰可见。背景是花园的柔焦画面,点缀着绿色的树叶和透过树叶洒下的阳光。2.“幼苗生长”:一棵幼小的苹果树苗破土而出,嫩绿的叶子向天空舒展。场景设定在一个生机勃勃的花园中,温暖的金光照亮了它。幼苗的纤细结构。3.“植物的成熟”:一棵成熟的苹果树,枝繁叶茂,挂满了嫩绿的叶子和正在萌发的小苹果。背景是一片生机勃勃的果园,湛蓝的天空下,斑驳的阳光营造出宁静祥和的氛围。4.“采摘果实”:一只手伸向树上,摘下一个成熟的红苹果,苹果光滑的果皮在阳光下闪闪发光。画面展现了果园的丰收景象,背景中摆放着一篮篮的苹果,给人一种圆满满足的感觉。每幅插图都采用写实风格,注重细节,色彩和谐,展现了苹果生命周期的自然之美和发展过程。 + 6. 用户输入: 如果1代表红色,2代表绿色,3代表紫色,4代表黄色,请按照此规则生成四色彩虹。它的颜色顺序从上到下是3142 + 改写输出:图片由四个水平排列的彩色条纹组成,从上到下依次为紫色、红色、黄色和绿色。每个条纹上都居中放置一个白色数字。最上方的紫色条纹上是数字“3”,其下方红色条纹上是数字“1”,再下方黄色条纹上是数字“4”,最下方的绿色条纹上是数字“2”。所有数字均采用无衬线字体,颜色为纯白色,与背景色形成鲜明对比,确保了良好的可读性。条纹的颜色饱和度高,且带有轻微的纹理感,整体排版简洁明了,视觉效果清晰,没有多余的装饰元素,强调了数字信息本身。图片整体清晰度高,色彩准确,风格一致,具有较强的视觉吸引力。 + 7. 用户输入:石碑上刻着“关关雎鸠,在河之洲”,自然光照,背景是中式园林 + 改写输出:一块古老的石碑上刻着“关关雎鸠,在河之洲”,石碑表面布满岁月的痕迹,字迹清晰而深刻。自然光线从上方洒下,柔和地照亮石碑的每一个细节,增强了其历史感。背景是一座典雅的中式园林,园林中有翠绿的竹林、蜿蜒的小径和静谧的水池,营造出一种宁静而悠远的氛围。整体画面采用写实风格,细节丰富,光影效果自然,突出了石碑的文化底蕴和园林的古典美。 +# 输出格式 请直接输出改写优化后的 Prompt 内容,不要包含任何解释性语言或 JSON 格式,不要自行添加开头或结尾的引号。 +""" diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05117d35d3b4f1746610ec0d273d76f715215cfb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__init__.py @@ -0,0 +1,58 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modeling_latent_upsampler"] = ["LTXLatentUpsamplerModel"] + _import_structure["pipeline_ltx"] = ["LTXPipeline"] + _import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"] + _import_structure["pipeline_ltx_i2v_long_multi_prompt"] = ["LTXI2VLongMultiPromptPipeline"] + _import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"] + _import_structure["pipeline_ltx_latent_upsample"] = ["LTXLatentUpsamplePipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .modeling_latent_upsampler import LTXLatentUpsamplerModel + from .pipeline_ltx import LTXPipeline + from .pipeline_ltx_condition import LTXConditionPipeline + from .pipeline_ltx_i2v_long_multi_prompt import LTXI2VLongMultiPromptPipeline + from .pipeline_ltx_image2video import LTXImageToVideoPipeline + from .pipeline_ltx_latent_upsample import LTXLatentUpsamplePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc1ccb62b8180be257637a36b662df4e6f6ad850 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__pycache__/modeling_latent_upsampler.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__pycache__/modeling_latent_upsampler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ca51880c486e34a6570c9a10bd893b03bb9da4a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__pycache__/modeling_latent_upsampler.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__pycache__/pipeline_ltx.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__pycache__/pipeline_ltx.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8f9ccade3b7f73f33478eefd12bcc7c4011809d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__pycache__/pipeline_ltx.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__pycache__/pipeline_ltx_condition.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__pycache__/pipeline_ltx_condition.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c861aba7fcdbe9455607cd8637494485e32b33a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__pycache__/pipeline_ltx_condition.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__pycache__/pipeline_ltx_i2v_long_multi_prompt.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__pycache__/pipeline_ltx_i2v_long_multi_prompt.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d691e8125c10cd71bd0b1a6b1d17fffa9cabc4d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__pycache__/pipeline_ltx_i2v_long_multi_prompt.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__pycache__/pipeline_ltx_image2video.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__pycache__/pipeline_ltx_image2video.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5771ef97f578e5573f8077e846b424259d4e278 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__pycache__/pipeline_ltx_image2video.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__pycache__/pipeline_ltx_latent_upsample.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__pycache__/pipeline_ltx_latent_upsample.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..844ff41040c442236df2ccd5de7e5004f52a5938 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__pycache__/pipeline_ltx_latent_upsample.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__pycache__/pipeline_output.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__pycache__/pipeline_output.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99f05c08d530a9c5f393c844fdaee962fee07328 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/__pycache__/pipeline_output.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/modeling_latent_upsampler.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/modeling_latent_upsampler.py new file mode 100644 index 0000000000000000000000000000000000000000..f579cf00dbe7a314aca365b503e2ac22756d1ce0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/modeling_latent_upsampler.py @@ -0,0 +1,186 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin + + +class ResBlock(torch.nn.Module): + def __init__(self, channels: int, mid_channels: int | None = None, dims: int = 3): + super().__init__() + if mid_channels is None: + mid_channels = channels + + Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = torch.nn.GroupNorm(32, mid_channels) + self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = torch.nn.GroupNorm(32, channels) + self.activation = torch.nn.SiLU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + hidden_states = self.conv1(hidden_states) + hidden_states = self.norm1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.conv2(hidden_states) + hidden_states = self.norm2(hidden_states) + hidden_states = self.activation(hidden_states + residual) + return hidden_states + + +class PixelShuffleND(torch.nn.Module): + def __init__(self, dims, upscale_factors=(2, 2, 2)): + super().__init__() + + self.dims = dims + self.upscale_factors = upscale_factors + + if dims not in [1, 2, 3]: + raise ValueError("dims must be 1, 2, or 3") + + def forward(self, x): + if self.dims == 3: + # spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3) + return ( + x.unflatten(1, (-1, *self.upscale_factors[:3])) + .permute(0, 1, 5, 2, 6, 3, 7, 4) + .flatten(6, 7) + .flatten(4, 5) + .flatten(2, 3) + ) + elif self.dims == 2: + # spatial: b (c p1 p2) h w -> b c (h p1) (w p2) + return ( + x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3) + ) + elif self.dims == 1: + # temporal: b (c p1) f h w -> b c (f p1) h w + return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3) + + +class LTXLatentUpsamplerModel(ModelMixin, ConfigMixin): + """ + Model to spatially upsample VAE latents. + + Args: + in_channels (`int`, defaults to `128`): + Number of channels in the input latent + mid_channels (`int`, defaults to `512`): + Number of channels in the middle layers + num_blocks_per_stage (`int`, defaults to `4`): + Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`, defaults to `3`): + Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`, defaults to `True`): + Whether to spatially upsample the latent + temporal_upsample (`bool`, defaults to `False`): + Whether to temporally upsample the latent + """ + + @register_to_config + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 512, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + + ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = torch.nn.GroupNorm(32, mid_channels) + self.initial_activation = torch.nn.SiLU() + + self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]) + + if spatial_upsample and temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError("Either spatial_upsample or temporal_upsample must be True") + + self.post_upsample_res_blocks = torch.nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + if self.dims == 2: + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.initial_conv(hidden_states) + hidden_states = self.initial_norm(hidden_states) + hidden_states = self.initial_activation(hidden_states) + + for block in self.res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.upsampler(hidden_states) + + for block in self.post_upsample_res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.final_conv(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + else: + hidden_states = self.initial_conv(hidden_states) + hidden_states = self.initial_norm(hidden_states) + hidden_states = self.initial_activation(hidden_states) + + for block in self.res_blocks: + hidden_states = block(hidden_states) + + if self.temporal_upsample: + hidden_states = self.upsampler(hidden_states) + hidden_states = hidden_states[:, :, 1:, :, :] + else: + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.upsampler(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + + for block in self.post_upsample_res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.final_conv(hidden_states) + + return hidden_states diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/pipeline_ltx.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/pipeline_ltx.py new file mode 100644 index 0000000000000000000000000000000000000000..e2514c3bca249b3fd068e8f0b45a91a4d579d12e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/pipeline_ltx.py @@ -0,0 +1,852 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin +from ...models.autoencoders import AutoencoderKLLTXVideo +from ...models.transformers import LTXVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import LTXPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTXPipeline + >>> from diffusers.utils import export_to_video + + >>> pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage" + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> video = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=704, + ... height=480, + ... num_frames=161, + ... num_inference_steps=50, + ... ).frames[0] + >>> export_to_video(video, "output.mp4", fps=24) + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin): + r""" + Pipeline for text-to-video generation. + + Reference: https://github.com/Lightricks/LTX-Video + + Args: + transformer ([`LTXVideoTransformer3DModel`]): + Conditional Transformer architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLLTXVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTXVideo, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: LTXVideoTransformer3DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128 + ) + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 128, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128 + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + max_sequence_length: int = 128, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + @staticmethod + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. + batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + @staticmethod + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + def prepare_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 704, + num_frames: int = 161, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + height = height // self.vae_spatial_compression_ratio + width = width // self.vae_spatial_compression_ratio + num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + height: int = 512, + width: int = 704, + num_frames: int = 161, + frame_rate: int = 25, + num_inference_steps: int = 50, + timesteps: list[int] = None, + guidance_scale: float = 3, + guidance_rescale: float = 0.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + decode_timestep: float | list[float] = 0.0, + decode_noise_scale: float | list[float] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 128, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, defaults to `512`): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, defaults to `704`): + The width in pixels of the generated image. This is set to 848 by default for the best results. + num_frames (`int`, defaults to `161`): + The number of video frames to generate + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, defaults to `3 `): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `None`): + The interpolation factor between random noise and denoised latents at the decode timestep. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `128 `): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + video_sequence_length = latent_num_frames * latent_height * latent_width + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + mu = calculate_shift( + video_sequence_length, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + timestep_device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Prepare micro-conditions + rope_interpolation_scale = ( + self.vae_temporal_compression_ratio / frame_rate, + self.vae_spatial_compression_ratio, + self.vae_spatial_compression_ratio, + ) + + # 7. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + rope_interpolation_scale=rope_interpolation_scale, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.guidance_rescale > 0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg( + noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + video = latents + else: + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + latents = latents.to(self.vae.dtype) + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return LTXPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/pipeline_ltx_condition.py new file mode 100644 index 0000000000000000000000000000000000000000..539a28f56e672fb3f56b6d28a731e4a65962ac8f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -0,0 +1,1287 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable + +import PIL.Image +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin +from ...models.autoencoders import AutoencoderKLLTXVideo +from ...models.transformers import LTXVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import LTXPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXConditionPipeline, LTXVideoCondition + >>> from diffusers.utils import export_to_video, load_video, load_image + + >>> pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> # Load input image and video + >>> video = load_video( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4" + ... ) + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input.jpg" + ... ) + + >>> # Create conditioning objects + >>> condition1 = LTXVideoCondition( + ... image=image, + ... frame_index=0, + ... ) + >>> condition2 = LTXVideoCondition( + ... video=video, + ... frame_index=80, + ... ) + + >>> prompt = "The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered clouds, suggesting a bright, sunny day. And then the camera switch to a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region." + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> # Generate video + >>> generator = torch.Generator("cuda").manual_seed(0) + >>> # Text-only conditioning is also supported without the need to pass `conditions` + >>> video = pipe( + ... conditions=[condition1, condition2], + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=768, + ... height=512, + ... num_frames=161, + ... num_inference_steps=40, + ... generator=generator, + ... ).frames[0] + + >>> export_to_video(video, "output.mp4", fps=24) + ``` +""" + + +@dataclass +class LTXVideoCondition: + """ + Defines a single frame-conditioning item for LTX Video - a single frame or a sequence of frames. + + Attributes: + image (`PIL.Image.Image`): + The image to condition the video on. + video (`list[PIL.Image.Image]`): + The video to condition the video on. + frame_index (`int`): + The frame index at which the image or video will conditionally effect the video generation. + strength (`float`, defaults to `1.0`): + The strength of the conditioning effect. A value of `1.0` means the conditioning effect is fully applied. + """ + + image: PIL.Image.Image | None = None + video: list[PIL.Image.Image] | None = None + frame_index: int = 0 + strength: float = 1.0 + + +# from LTX-Video/ltx_video/schedulers/rf.py +def linear_quadratic_schedule(num_steps, threshold_noise=0.025, linear_steps=None): + if linear_steps is None: + linear_steps = num_steps // 2 + if num_steps < 2: + return torch.tensor([1.0]) + linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)] + threshold_noise_step_diff = linear_steps - threshold_noise * num_steps + quadratic_steps = num_steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2) + const = quadratic_coef * (linear_steps**2) + quadratic_sigma_schedule = [ + quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps) + ] + sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0] + sigma_schedule = [1.0 - x for x in sigma_schedule] + return torch.tensor(sigma_schedule[:-1]) + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin): + r""" + Pipeline for text/image/video-to-video generation. + + Reference: https://github.com/Lightricks/LTX-Video + + Args: + transformer ([`LTXVideoTransformer3DModel`]): + Conditional Transformer architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLLTXVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTXVideo, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: LTXVideoTransformer3DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128 + ) + + self.default_height = 512 + self.default_width = 704 + self.default_frames = 121 + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 256, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + max_sequence_length: int = 256, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + def check_inputs( + self, + prompt, + conditions, + image, + video, + frame_index, + strength, + denoise_strength, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + if conditions is not None and (image is not None or video is not None): + raise ValueError("If `conditions` is provided, `image` and `video` must not be provided.") + + if conditions is None: + if isinstance(image, list) and isinstance(frame_index, list) and len(image) != len(frame_index): + raise ValueError( + "If `conditions` is not provided, `image` and `frame_index` must be of the same length." + ) + elif isinstance(image, list) and isinstance(strength, list) and len(image) != len(strength): + raise ValueError("If `conditions` is not provided, `image` and `strength` must be of the same length.") + elif isinstance(video, list) and isinstance(frame_index, list) and len(video) != len(frame_index): + raise ValueError( + "If `conditions` is not provided, `video` and `frame_index` must be of the same length." + ) + elif isinstance(video, list) and isinstance(strength, list) and len(video) != len(strength): + raise ValueError("If `conditions` is not provided, `video` and `strength` must be of the same length.") + + if denoise_strength < 0 or denoise_strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {denoise_strength}") + + @staticmethod + def _prepare_video_ids( + batch_size: int, + num_frames: int, + height: int, + width: int, + patch_size: int = 1, + patch_size_t: int = 1, + device: torch.device = None, + ) -> torch.Tensor: + latent_sample_coords = torch.meshgrid( + torch.arange(0, num_frames, patch_size_t, device=device), + torch.arange(0, height, patch_size, device=device), + torch.arange(0, width, patch_size, device=device), + indexing="ij", + ) + latent_sample_coords = torch.stack(latent_sample_coords, dim=0) + latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_coords = latent_coords.reshape(batch_size, -1, num_frames * height * width) + + return latent_coords + + @staticmethod + def _scale_video_ids( + video_ids: torch.Tensor, + scale_factor: int = 32, + scale_factor_t: int = 8, + frame_index: int = 0, + device: torch.device = None, + ) -> torch.Tensor: + scaled_latent_coords = ( + video_ids + * torch.tensor([scale_factor_t, scale_factor, scale_factor], device=video_ids.device)[None, :, None] + ) + scaled_latent_coords[:, 0] = (scaled_latent_coords[:, 0] + 1 - scale_factor_t).clamp(min=0) + scaled_latent_coords[:, 0] += frame_index + + return scaled_latent_coords + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. + batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + def trim_conditioning_sequence(self, start_frame: int, sequence_num_frames: int, target_num_frames: int): + """ + Trim a conditioning sequence to the allowed number of frames. + + Args: + start_frame (int): The target frame number of the first frame in the sequence. + sequence_num_frames (int): The number of frames in the sequence. + target_num_frames (int): The target number of frames in the generated video. + Returns: + int: updated sequence length + """ + scale_factor = self.vae_temporal_compression_ratio + num_frames = min(sequence_num_frames, target_num_frames - start_frame) + # Trim down to a multiple of temporal_scale_factor frames plus 1 + num_frames = (num_frames - 1) // scale_factor * scale_factor + 1 + return num_frames + + @staticmethod + def add_noise_to_image_conditioning_latents( + t: float, + init_latents: torch.Tensor, + latents: torch.Tensor, + noise_scale: float, + conditioning_mask: torch.Tensor, + generator, + eps=1e-6, + ): + """ + Add timestep-dependent noise to the hard-conditioning latents. This helps with motion continuity, especially + when conditioned on a single frame. + """ + noise = randn_tensor( + latents.shape, + generator=generator, + device=latents.device, + dtype=latents.dtype, + ) + # Add noise only to hard-conditioning latents (conditioning_mask = 1.0) + need_to_noise = (conditioning_mask > 1.0 - eps).unsqueeze(-1) + noised_latents = init_latents + noise_scale * noise * (t**2) + latents = torch.where(need_to_noise, noised_latents, latents) + return latents + + def prepare_latents( + self, + conditions: list[torch.Tensor] | None = None, + condition_strength: list[float] | None = None, + condition_frame_index: list[int] | None = None, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 704, + num_frames: int = 161, + num_prefix_latent_frames: int = 2, + sigma: torch.Tensor | None = None, + latents: torch.Tensor | None = None, + generator: torch.Generator | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + num_latent_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if latents is not None and sigma is not None: + if latents.shape != shape: + raise ValueError( + f"Latents shape {latents.shape} does not match expected shape {shape}. Please check the input." + ) + latents = latents.to(device=device, dtype=dtype) + sigma = sigma.to(device=device, dtype=dtype) + latents = sigma * noise + (1 - sigma) * latents + else: + latents = noise + + if len(conditions) > 0: + condition_latent_frames_mask = torch.zeros( + (batch_size, num_latent_frames), device=device, dtype=torch.float32 + ) + + extra_conditioning_latents = [] + extra_conditioning_video_ids = [] + extra_conditioning_mask = [] + extra_conditioning_num_latents = 0 + for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index): + condition_latents = retrieve_latents(self.vae.encode(data), generator=generator) + condition_latents = self._normalize_latents( + condition_latents, self.vae.latents_mean, self.vae.latents_std + ).to(device, dtype=dtype) + + num_data_frames = data.size(2) + num_cond_frames = condition_latents.size(2) + + if frame_index == 0: + latents[:, :, :num_cond_frames] = torch.lerp( + latents[:, :, :num_cond_frames], condition_latents, strength + ) + condition_latent_frames_mask[:, :num_cond_frames] = strength + + else: + if num_data_frames > 1: + if num_cond_frames < num_prefix_latent_frames: + raise ValueError( + f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}." + ) + + if num_cond_frames > num_prefix_latent_frames: + start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames + end_frame = start_frame + num_cond_frames - num_prefix_latent_frames + latents[:, :, start_frame:end_frame] = torch.lerp( + latents[:, :, start_frame:end_frame], + condition_latents[:, :, num_prefix_latent_frames:], + strength, + ) + condition_latent_frames_mask[:, start_frame:end_frame] = strength + condition_latents = condition_latents[:, :, :num_prefix_latent_frames] + + noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype) + condition_latents = torch.lerp(noise, condition_latents, strength) + + condition_video_ids = self._prepare_video_ids( + batch_size, + condition_latents.size(2), + latent_height, + latent_width, + patch_size=self.transformer_spatial_patch_size, + patch_size_t=self.transformer_temporal_patch_size, + device=device, + ) + condition_video_ids = self._scale_video_ids( + condition_video_ids, + scale_factor=self.vae_spatial_compression_ratio, + scale_factor_t=self.vae_temporal_compression_ratio, + frame_index=frame_index, + device=device, + ) + condition_latents = self._pack_latents( + condition_latents, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + condition_conditioning_mask = torch.full( + condition_latents.shape[:2], strength, device=device, dtype=dtype + ) + + extra_conditioning_latents.append(condition_latents) + extra_conditioning_video_ids.append(condition_video_ids) + extra_conditioning_mask.append(condition_conditioning_mask) + extra_conditioning_num_latents += condition_latents.size(1) + + video_ids = self._prepare_video_ids( + batch_size, + num_latent_frames, + latent_height, + latent_width, + patch_size_t=self.transformer_temporal_patch_size, + patch_size=self.transformer_spatial_patch_size, + device=device, + ) + if len(conditions) > 0: + conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0]) + else: + conditioning_mask, extra_conditioning_num_latents = None, 0 + video_ids = self._scale_video_ids( + video_ids, + scale_factor=self.vae_spatial_compression_ratio, + scale_factor_t=self.vae_temporal_compression_ratio, + frame_index=0, + device=device, + ) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + if len(conditions) > 0 and len(extra_conditioning_latents) > 0: + latents = torch.cat([*extra_conditioning_latents, latents], dim=1) + video_ids = torch.cat([*extra_conditioning_video_ids, video_ids], dim=2) + conditioning_mask = torch.cat([*extra_conditioning_mask, conditioning_mask], dim=1) + + return latents, conditioning_mask, video_ids, extra_conditioning_num_latents + + def get_timesteps(self, sigmas, timesteps, num_inference_steps, strength): + num_steps = min(int(num_inference_steps * strength), num_inference_steps) + start_index = max(num_inference_steps - num_steps, 0) + sigmas = sigmas[start_index:] + timesteps = timesteps[start_index:] + return sigmas, timesteps, num_inference_steps - start_index + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + conditions: LTXVideoCondition | list[LTXVideoCondition] = None, + image: PipelineImageInput | list[PipelineImageInput] = None, + video: list[PipelineImageInput] = None, + frame_index: int | list[int] = 0, + strength: float | list[float] = 1.0, + denoise_strength: float = 1.0, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + height: int = 512, + width: int = 704, + num_frames: int = 161, + frame_rate: int = 25, + num_inference_steps: int = 50, + timesteps: list[int] = None, + guidance_scale: float = 3, + guidance_rescale: float = 0.0, + image_cond_noise_scale: float = 0.15, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + decode_timestep: float | list[float] = 0.0, + decode_noise_scale: float | list[float] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 256, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + conditions (`list[LTXVideoCondition], *optional*`): + The list of frame-conditioning items for the video generation.If not provided, conditions will be + created using `image`, `video`, `frame_index` and `strength`. + image (`PipelineImageInput` or `list[PipelineImageInput]`, *optional*): + The image or images to condition the video generation. If not provided, one has to pass `video` or + `conditions`. + video (`list[PipelineImageInput]`, *optional*): + The video to condition the video generation. If not provided, one has to pass `image` or `conditions`. + frame_index (`int` or `list[int]`, *optional*): + The frame index or frame indices at which the image or video will conditionally effect the video + generation. If not provided, one has to pass `conditions`. + strength (`float` or `list[float]`, *optional*): + The strength or strengths of the conditioning effect. If not provided, one has to pass `conditions`. + denoise_strength (`float`, defaults to `1.0`): + The strength of the noise added to the latents for editing. Higher strength leads to more noise added + to the latents, therefore leading to more differences between original video and generated video. This + is useful for video-to-video editing. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, defaults to `512`): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, defaults to `704`): + The width in pixels of the generated image. This is set to 848 by default for the best results. + num_frames (`int`, defaults to `161`): + The number of video frames to generate + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, defaults to `3 `): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `None`): + The interpolation factor between random noise and denoised latents at the decode timestep. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `128 `): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + conditions=conditions, + image=image, + video=video, + frame_index=frame_index, + strength=strength, + denoise_strength=denoise_strength, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if conditions is not None: + if not isinstance(conditions, list): + conditions = [conditions] + + strength = [condition.strength for condition in conditions] + frame_index = [condition.frame_index for condition in conditions] + image = [condition.image for condition in conditions] + video = [condition.video for condition in conditions] + elif image is not None or video is not None: + if not isinstance(image, list): + image = [image] + num_conditions = 1 + elif isinstance(image, list): + num_conditions = len(image) + if not isinstance(video, list): + video = [video] + num_conditions = 1 + elif isinstance(video, list): + num_conditions = len(video) + + if not isinstance(frame_index, list): + frame_index = [frame_index] * num_conditions + if not isinstance(strength, list): + strength = [strength] * num_conditions + + device = self._execution_device + vae_dtype = self.vae.dtype + + # 3. Prepare text embeddings & conditioning image/video + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + conditioning_tensors = [] + is_conditioning_image_or_video = image is not None or video is not None + if is_conditioning_image_or_video: + for condition_image, condition_video, condition_frame_index, condition_strength in zip( + image, video, frame_index, strength + ): + if condition_image is not None: + condition_tensor = ( + self.video_processor.preprocess(condition_image, height, width) + .unsqueeze(2) + .to(device, dtype=vae_dtype) + ) + elif condition_video is not None: + condition_tensor = self.video_processor.preprocess_video(condition_video, height, width) + num_frames_input = condition_tensor.size(2) + num_frames_output = self.trim_conditioning_sequence( + condition_frame_index, num_frames_input, num_frames + ) + condition_tensor = condition_tensor[:, :, :num_frames_output] + condition_tensor = condition_tensor.to(device, dtype=vae_dtype) + else: + raise ValueError("Either `image` or `video` must be provided for conditioning.") + + if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1: + raise ValueError( + f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) " + f"but got {condition_tensor.size(2)} frames." + ) + conditioning_tensors.append(condition_tensor) + + # 4. Prepare timesteps + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + + if timesteps is None: + sigmas = linear_quadratic_schedule(num_inference_steps) + timesteps = sigmas * 1000 + + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + timestep_device, + timesteps, + ) + sigmas = self.scheduler.sigmas + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + latent_sigma = None + if denoise_strength < 1: + sigmas, timesteps, num_inference_steps = self.get_timesteps( + sigmas, timesteps, num_inference_steps, denoise_strength + ) + latent_sigma = sigmas[:1].repeat(batch_size * num_videos_per_prompt) + + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents, conditioning_mask, video_coords, extra_conditioning_num_latents = self.prepare_latents( + conditioning_tensors, + strength, + frame_index, + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames=num_frames, + sigma=latent_sigma, + latents=latents, + generator=generator, + device=device, + dtype=torch.float32, + ) + + video_coords = video_coords.float() + video_coords[:, 0] = video_coords[:, 0] * (1.0 / frame_rate) + + init_latents = latents.clone() if is_conditioning_image_or_video else None + + if self.do_classifier_free_guidance: + video_coords = torch.cat([video_coords, video_coords], dim=0) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + if image_cond_noise_scale > 0 and init_latents is not None: + # Add timestep-dependent noise to the hard-conditioning latents + # This helps with motion continuity, especially when conditioned on a single frame + latents = self.add_noise_to_image_conditioning_latents( + t / 1000.0, + init_latents, + latents, + image_cond_noise_scale, + conditioning_mask, + generator, + ) + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + if is_conditioning_image_or_video: + conditioning_mask_model_input = ( + torch.cat([conditioning_mask, conditioning_mask]) + if self.do_classifier_free_guidance + else conditioning_mask + ) + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float() + if is_conditioning_image_or_video: + timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0) + + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + video_coords=video_coords, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + timestep, _ = timestep.chunk(2) + + if self.guidance_rescale > 0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg( + noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale + ) + + denoised_latents = self.scheduler.step( + -noise_pred, t, latents, per_token_timesteps=timestep, return_dict=False + )[0] + if is_conditioning_image_or_video: + tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1) + latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents) + else: + latents = denoised_latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if is_conditioning_image_or_video: + latents = latents[:, extra_conditioning_num_latents:] + + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + + if output_type == "latent": + video = latents + else: + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return LTXPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/pipeline_ltx_i2v_long_multi_prompt.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/pipeline_ltx_i2v_long_multi_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..838d5afc5c5adc37dc7500217fc35d758868b3fb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/pipeline_ltx_i2v_long_multi_prompt.py @@ -0,0 +1,1410 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import copy +from typing import Any, Callable + +import numpy as np +import PIL +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin +from ...models.autoencoders import AutoencoderKLLTXVideo +from ...models.transformers import LTXVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler, LTXEulerAncestralRFScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import LTXPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTXEulerAncestralRFScheduler, LTXI2VLongMultiPromptPipeline + + >>> pipe = LTXI2VLongMultiPromptPipeline.from_pretrained("LTX-Video-0.9.8-13B-distilled") + >>> # For ComfyUI parity, swap in the RF scheduler (keeps the original config). + >>> pipe.scheduler = LTXEulerAncestralRFScheduler.from_config(pipe.scheduler.config) + >>> pipe = pipe.to("cuda").to(dtype=torch.bfloat16) + >>> # Example A: get decoded frames (PIL) + >>> out = pipe( + ... prompt="a chimpanzee walks | a chimpanzee eats", + ... num_frames=161, + ... height=512, + ... width=704, + ... temporal_tile_size=80, + ... temporal_overlap=24, + ... output_type="pil", + ... return_dict=True, + ... ) + >>> frames = out.frames[0] # list of PIL.Image.Image + >>> # Example B: get latent video and decode later (saves VRAM during sampling) + >>> out_latent = pipe(prompt="a chimpanzee walking", output_type="latent", return_dict=True).frames + >>> frames = pipe.vae_decode_tiled(out_latent, output_type="pil")[0] + ``` +""" + + +def get_latent_coords( + latent_num_frames, latent_height, latent_width, batch_size, device, rope_interpolation_scale, latent_idx +): + """ + Compute latent patch top-left coordinates in (t, y, x) order. + + Args: + latent_num_frames: int. Number of latent frames (T_lat). + latent_height: int. Latent height (H_lat). + latent_width: int. Latent width (W_lat). + batch_size: int. Batch dimension (B). + device: torch.device for the resulting tensor. + rope_interpolation_scale: + tuple[int|float, int|float, int|float]. Scale per (t, y, x) latent step to pixel coords. + latent_idx: int | None. When not None, shifts the time coordinate to align segments: + - <= 0 uses step multiples of rope_interpolation_scale[0] + - > 0 starts at 1 then increments by rope_interpolation_scale[0] + + Returns: + Tensor of shape [B, 3, T_lat * H_lat * W_lat] containing top-left coordinates per latent patch, repeated for each + batch element. + """ + latent_sample_coords = torch.meshgrid( + torch.arange(0, latent_num_frames, 1, device=device), + torch.arange(0, latent_height, 1, device=device), + torch.arange(0, latent_width, 1, device=device), + indexing="ij", + ) + latent_sample_coords = torch.stack(latent_sample_coords, dim=0) + latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_coords = latent_coords.flatten(2) + pixel_coords = latent_coords * torch.tensor(rope_interpolation_scale, device=latent_coords.device)[None, :, None] + if latent_idx is not None: + if latent_idx <= 0: + frame_idx = latent_idx * rope_interpolation_scale[0] + else: + frame_idx = 1 + (latent_idx - 1) * rope_interpolation_scale[0] + if frame_idx == 0: + pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - rope_interpolation_scale[0]).clamp(min=0) + pixel_coords[:, 0] += frame_idx + return pixel_coords + + +# Copied from diffusers.pipelines.ltx.pipeline_ltx.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +def adain_normalize_latents( + curr_latents: torch.Tensor, ref_latents: torch.Tensor | None, factor: float +) -> torch.Tensor: + """ + Optional AdaIN normalization: channel-wise mean/variance matching of curr_latents to ref_latents, controlled by + factor. + + Args: + curr_latents: Tensor [B, C, T, H, W]. Current window latents. + ref_latents: + Tensor | None [B, C, T_ref, H, W]. Reference latents (e.g., first window) used to compute target stats. + factor: float in [0, 1]. 0 keeps current stats; 1 matches reference stats. + + Returns: + Tensor with per-channel mean/std blended towards the reference. + """ + if ref_latents is None or factor is None or factor <= 0: + return curr_latents + + eps = torch.tensor(1e-6, device=curr_latents.device, dtype=curr_latents.dtype) + + # Compute per-channel means/stds for current and reference over (T, H, W) + mu_curr = curr_latents.mean(dim=(2, 3, 4), keepdim=True) + sigma_curr = curr_latents.std(dim=(2, 3, 4), keepdim=True) + + mu_ref = ref_latents.mean(dim=(2, 3, 4), keepdim=True).to(device=curr_latents.device, dtype=curr_latents.dtype) + sigma_ref = ref_latents.std(dim=(2, 3, 4), keepdim=True).to(device=curr_latents.device, dtype=curr_latents.dtype) + + # Blend target statistics + mu_blend = (1.0 - float(factor)) * mu_curr + float(factor) * mu_ref + sigma_blend = (1.0 - float(factor)) * sigma_curr + float(factor) * sigma_ref + sigma_blend = torch.clamp(sigma_blend, min=float(eps)) + + # Apply AdaIN + curr_norm = (curr_latents - mu_curr) / (sigma_curr + eps) + return curr_norm * sigma_blend + mu_blend + + +def split_into_temporal_windows( + latent_len: int, temporal_tile_size: int, temporal_overlap: int, compression: int +) -> list[tuple[int, int]]: + """ + Split latent frames into sliding windows. + + Args: + latent_len: int. Number of latent frames (T_lat). + temporal_tile_size: int. Window size in latent frames (> 0). + temporal_overlap: int. Overlap between windows in latent frames (>= 0). + compression: int. VAE temporal compression ratio (unused here; kept for parity). + + Returns: + list[tuple[int, int]]: inclusive-exclusive (start, end) indices per window. + """ + if temporal_tile_size <= 0: + raise ValueError("temporal_tile_size must be > 0") + stride = max(temporal_tile_size - temporal_overlap, 1) + windows = [] + start = 0 + while start < latent_len: + end = min(start + temporal_tile_size, latent_len) + windows.append((start, end)) + if end == latent_len: + break + start = start + stride + return windows + + +def linear_overlap_fuse(prev: torch.Tensor, new: torch.Tensor, overlap: int) -> torch.Tensor: + """ + Temporal linear crossfade between two latent clips over the overlap region. + + Args: + prev: Tensor [B, C, F, H, W]. Previous output segment. + new: Tensor [B, C, F, H, W]. New segment to be appended. + overlap: int. Number of frames to crossfade (overlap <= 1 concatenates without blend). + + Returns: + Tensor [B, C, F_prev + F_new - overlap, H, W] after crossfade at the seam. + """ + if overlap <= 1: + return torch.cat([prev, new], dim=2) + alpha = torch.linspace(1, 0, overlap + 2, device=prev.device, dtype=prev.dtype)[1:-1] + shape = [1] * prev.ndim + shape[2] = alpha.size(0) + alpha = alpha.reshape(shape) + blended = alpha * prev[:, :, -overlap:] + (1 - alpha) * new[:, :, :overlap] + return torch.cat([prev[:, :, :-overlap], blended, new[:, :, overlap:]], dim=2) + + +def inject_prev_tail_latents( + window_latents: torch.Tensor, + prev_tail_latents: torch.Tensor | None, + window_cond_mask_5d: torch.Tensor, + overlap_lat: int, + strength: float | None, + prev_overlap_len: int, +) -> tuple[torch.Tensor, torch.Tensor, int]: + """ + Inject the tail latents from the previous window at the beginning of the current window (first k frames), where k = + min(overlap_lat, T_curr, T_prev_tail). + + Args: + window_latents: Tensor [B, C, T, H, W]. Current window latents. + prev_tail_latents: Tensor | None [B, C, T_prev, H, W]. Tail segment from the previous window. + window_cond_mask_5d: Tensor [B, 1, T, H, W]. Per-token conditioning mask (1 = free, 0 = hard condition). + overlap_lat: int. Number of latent frames to inject from the previous tail. + strength: float | None in [0, 1]. Blend strength; 1.0 replaces, 0.0 keeps original. + prev_overlap_len: int. Accumulated overlap length so far (used for trimming later). + + Returns: + tuple[Tensor, Tensor, int]: (updated_window_latents, updated_cond_mask, updated_prev_overlap_len) + """ + if prev_tail_latents is None or overlap_lat <= 0 or strength is None or strength <= 0: + return window_latents, window_cond_mask_5d, prev_overlap_len + + # Expected shape: [B, C, T, H, W] + T = int(window_latents.shape[2]) + k = min(int(overlap_lat), T, int(prev_tail_latents.shape[2])) + if k <= 0: + return window_latents, window_cond_mask_5d, prev_overlap_len + + tail = prev_tail_latents[:, :, -k:] + mask = torch.full( + (window_cond_mask_5d.shape[0], 1, tail.shape[2], window_cond_mask_5d.shape[3], window_cond_mask_5d.shape[4]), + 1.0 - strength, + dtype=window_cond_mask_5d.dtype, + device=window_cond_mask_5d.device, + ) + + window_latents = torch.cat([window_latents, tail], dim=2) + window_cond_mask_5d = torch.cat([window_cond_mask_5d, mask], dim=2) + return window_latents, window_cond_mask_5d, prev_overlap_len + k + + +def build_video_coords_for_window( + latents: torch.Tensor, + overlap_len: int, + guiding_len: int, + negative_len: int, + rope_interpolation_scale: torch.Tensor, + frame_rate: int, +) -> torch.Tensor: + """ + Build video_coords: [B, 3, S] with order [t, y, x]. + + Args: + latents: Tensor [B, C, T, H, W]. Current window latents (before any trimming). + overlap_len: int. Number of frames from previous tail injected at the head. + guiding_len: int. Number of guidance frames appended at the head. + negative_len: int. Number of negative-index frames appended at the head (typically 1 or 0). + rope_interpolation_scale: tuple[int|float, int|float, int|float]. Scale for (t, y, x). + frame_rate: int. Used to convert time indices into seconds (t /= frame_rate). + + Returns: + Tensor [B, 3, T*H*W] of fractional pixel coordinates per latent patch. + """ + + b, c, f, h, w = latents.shape + pixel_coords = get_latent_coords(f, h, w, b, latents.device, rope_interpolation_scale, 0) + replace_corrds = [] + if overlap_len > 0: + replace_corrds.append(get_latent_coords(overlap_len, h, w, b, latents.device, rope_interpolation_scale, 0)) + if guiding_len > 0: + replace_corrds.append( + get_latent_coords(guiding_len, h, w, b, latents.device, rope_interpolation_scale, overlap_len) + ) + if negative_len > 0: + replace_corrds.append(get_latent_coords(negative_len, h, w, b, latents.device, rope_interpolation_scale, -1)) + if len(replace_corrds) > 0: + replace_corrds = torch.cat(replace_corrds, axis=2) + pixel_coords[:, :, -replace_corrds.shape[2] :] = replace_corrds + fractional_coords = pixel_coords.to(torch.float32) + fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) + return fractional_coords + + +def parse_prompt_segments(prompt: str | list[str], prompt_segments: list[dict[str, Any]] | None) -> list[str]: + """ + Return a list of positive prompts per window index. + + Args: + prompt: str | list[str]. If str contains '|', parts are split by bars and trimmed. + prompt_segments: + list[dict], optional. Each dict with {"start_window", "end_window", "text"} overrides prompts per window. + + Returns: + list[str] containing the positive prompt for each window index. + """ + if prompt is None: + return [] + if prompt_segments: + max_w = 0 + for seg in prompt_segments: + max_w = max(max_w, int(seg.get("end_window", 0))) + texts = [""] * (max_w + 1) + for seg in prompt_segments: + s = int(seg.get("start_window", 0)) + e = int(seg.get("end_window", s)) + txt = seg.get("text", "") + for w in range(s, e + 1): + texts[w] = txt + # fill empty by last non-empty + last = "" + for i in range(len(texts)): + if texts[i] == "": + texts[i] = last + else: + last = texts[i] + return texts + + # bar-split mode + if isinstance(prompt, str): + parts = [p.strip() for p in prompt.split("|")] + else: + parts = prompt + parts = [p for p in parts if p is not None] + return parts + + +def batch_normalize(latents, reference, factor): + """ + Batch AdaIN-like normalization for latents in dict format (ComfyUI-compatible). + + Args: + latents: dict containing "samples" shaped [B, C, F, H, W] + reference: dict containing "samples" used to compute target stats + factor: float in [0, 1]; 0 = no change, 1 = full match to reference + Returns: + tuple[dict]: a single-element tuple with the updated latents dict. + """ + latents_copy = copy.deepcopy(latents) + t = latents_copy["samples"] # B x C x F x H x W + + for i in range(t.size(0)): # batch + for c in range(t.size(1)): # channel + r_sd, r_mean = torch.std_mean(reference["samples"][i, c], dim=None) # index by original dim order + i_sd, i_mean = torch.std_mean(t[i, c], dim=None) + + t[i, c] = ((t[i, c] - i_mean) / i_sd) * r_sd + r_mean + + latents_copy["samples"] = torch.lerp(latents["samples"], t, factor) + return (latents_copy,) + + +class LTXI2VLongMultiPromptPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin): + r""" + Long-duration I2V (image-to-video) multi-prompt pipeline with ComfyUI parity. + + Key features: + - Temporal sliding-window sampling only (no spatial H/W sharding); autoregressive fusion across windows. + - Multi-prompt segmentation per window with smooth transitions at window heads. + - First-frame hard conditioning via per-token mask for I2V. + - VRAM control via temporal windowing and VAE tiled decoding. + + Reference: https://github.com/Lightricks/LTX-Video + + Args: + transformer ([`LTXVideoTransformer3DModel`]): + Conditional Transformer architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`] or [`LTXEulerAncestralRFScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLLTXVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`T5TokenizerFast`): + Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTXVideo, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: LTXVideoTransformer3DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + if not isinstance(scheduler, LTXEulerAncestralRFScheduler): + logger.warning( + "For ComfyUI parity, `LTXI2VLongMultiPromptPipeline` is typically run with " + "`LTXEulerAncestralRFScheduler`. Got %s.", + scheduler.__class__.__name__, + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128 + ) + + self.default_height = 512 + self.default_width = 704 + self.default_frames = 121 + self._current_tile_T = None + + @property + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.guidance_scale + def guidance_scale(self): + return self._guidance_scale + + @property + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.guidance_rescale + def guidance_rescale(self): + return self._guidance_rescale + + @property + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.do_classifier_free_guidance + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.num_timesteps + def num_timesteps(self): + return self._num_timesteps + + @property + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.current_timestep + def current_timestep(self): + return self._current_timestep + + @property + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.attention_kwargs + def attention_kwargs(self): + return self._attention_kwargs + + @property + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.interrupt + def interrupt(self): + return self._interrupt + + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 128, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + max_sequence_length: int = 128, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. + batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + num_frames: int, + device: torch.device, + generator: torch.Generator | None, + dtype: torch.dtype = torch.float32, + latents: torch.Tensor | None = None, + cond_latents: torch.Tensor | None = None, + cond_strength: float = 0.0, + negative_index_latents: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None, int, int, int]: + """ + Prepare base latents and optionally inject first-frame conditioning latents. + + Returns: + latents, negative_index_latents, latent_num_frames, latent_height, latent_width + """ + if latents is None: + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + latents = torch.zeros( + (batch_size, num_channels_latents, latent_num_frames, latent_height, latent_width), + device=device, + dtype=dtype, + ) + else: + latent_num_frames = latents.shape[2] + latent_height = latents.shape[3] + latent_width = latents.shape[4] + latents = latents.to(device=device, dtype=dtype) + + if cond_latents is not None and cond_strength > 0: + if negative_index_latents is None: + negative_index_latents = cond_latents + latents[:, :, :1, :, :] = cond_latents + + return latents, negative_index_latents, latent_num_frames, latent_height, latent_width + + # TODO: refactor this out + @torch.no_grad() + def vae_decode_tiled( + self, + latents: torch.Tensor, + decode_timestep: float | None = None, + decode_noise_scale: float | None = None, + horizontal_tiles: int = 4, + vertical_tiles: int = 4, + overlap: int = 3, + last_frame_fix: bool = True, + generator: torch.Generator | None = None, + output_type: str = "pt", + auto_denormalize: bool = True, + compute_dtype: torch.dtype = torch.float32, + enable_vae_tiling: bool = False, + ) -> torch.Tensor | np.ndarray | list[PIL.Image.Image]: + """ + VAE-based spatial tiled decoding (ComfyUI parity) implemented in Diffusers style. + - Linearly feather and blend overlapping tiles to avoid seams. + - Optional last_frame_fix: duplicate the last latent frame before decoding, then drop time_scale_factor frames + at the end. + - Supports timestep_conditioning and decode_noise_scale injection. + - By default, "normalized latents" (the denoising output) are de-normalized internally (auto_denormalize=True). + - Tile fusion is computed in compute_dtype (float32 by default) to reduce blur and color shifts. + + Args: + latents: [B, C_latent, F_latent, H_latent, W_latent] + decode_timestep: Optional decode timestep (effective only if VAE supports timestep_conditioning) + decode_noise_scale: + Optional decode noise interpolation (effective only if VAE supports timestep_conditioning) + horizontal_tiles, vertical_tiles: Number of tiles horizontally/vertically (>= 1) + overlap: Overlap in latent space (in latent pixels, >= 0) + last_frame_fix: Whether to enable the "repeat last frame" fix + generator: Random generator (used for decode_noise_scale noise) + output_type: "latent" | "pt" | "np" | "pil" + - "latent": return latents unchanged (useful for downstream processing) + - "pt": return tensor in VAE output space + - "np"/"pil": post-processed outputs via VideoProcessor.postprocess_video + auto_denormalize: If True, apply LTX de-normalization to `latents` internally (recommended) + compute_dtype: Precision used during tile fusion (float32 default; significantly reduces seam blur) + enable_vae_tiling: If True, delegate tiling to VAE's built-in `tiled_decode` (sets `vae.use_tiling`). + + Returns: + - If output_type="latent": returns input `latents` unchanged + - If output_type="pt": returns [B, C, F, H, W] (values roughly in [-1, 1]) + - If output_type="np"/"pil": returns post-processed outputs via postprocess_video + """ + if output_type == "latent": + return latents + if horizontal_tiles < 1 or vertical_tiles < 1: + raise ValueError("horizontal_tiles and vertical_tiles must be >= 1") + overlap = max(int(overlap), 0) + + # Device and precision + device = self._execution_device + latents = latents.to(device=device, dtype=compute_dtype) + + # De-normalize to VAE space (avoid color artifacts) + if auto_denormalize: + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + # dtype required for VAE forward pass + latents = latents.to(dtype=self.vae.dtype) + + # Temporal/spatial upscaling ratios (parity with ComfyUI's downscale_index_formula) + tsf = int(self.vae_temporal_compression_ratio) + sf = int(self.vae_spatial_compression_ratio) + + # Optional: last_frame_fix (repeat last latent frame) + if last_frame_fix: + latents = torch.cat([latents, latents[:, :, -1:].contiguous()], dim=2) + + b, c_lat, f_lat, h_lat, w_lat = latents.shape + f_out = 1 + (f_lat - 1) * tsf + h_out = h_lat * sf + w_out = w_lat * sf + + # timestep_conditioning + decode-time noise injection (aligned with pipeline) + if getattr(self.vae.config, "timestep_conditioning", False): + dt = float(decode_timestep) if decode_timestep is not None else 0.0 + vt = torch.tensor([dt], device=device, dtype=latents.dtype) + if decode_noise_scale is not None: + dns = torch.tensor([float(decode_noise_scale)], device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + latents = (1 - dns) * latents + dns * noise + else: + vt = None + + if enable_vae_tiling and hasattr(self.vae, "enable_tiling"): + self.vae.enable_tiling() + decoded = self.vae.decode(latents, vt, return_dict=False)[0] + if last_frame_fix: + decoded = decoded[:, :, :-tsf, :, :] + if output_type in ("np", "pil"): + return self.video_processor.postprocess_video(decoded, output_type=output_type) + return decoded + + # Compute base tile sizes (in latent space) + base_tile_h = (h_lat + (vertical_tiles - 1) * overlap) // vertical_tiles + base_tile_w = (w_lat + (horizontal_tiles - 1) * overlap) // horizontal_tiles + + output: torch.Tensor | None = None # [B, C_img, F, H, W], fused using compute_dtype + weights: torch.Tensor | None = None # [B, 1, F, H, W], fused using compute_dtype + + # Iterate tiles in latent space (no temporal tiling) + for v in range(vertical_tiles): + for h in range(horizontal_tiles): + h_start = h * (base_tile_w - overlap) + v_start = v * (base_tile_h - overlap) + + h_end = min(h_start + base_tile_w, w_lat) if h < horizontal_tiles - 1 else w_lat + v_end = min(v_start + base_tile_h, h_lat) if v < vertical_tiles - 1 else h_lat + + # Slice latent tile and decode + tile_latents = latents[:, :, :, v_start:v_end, h_start:h_end] + decoded_tile = self.vae.decode(tile_latents, vt, return_dict=False)[0] # [B, C, F, Ht, Wt] + # Cast to high precision to reduce blending blur + decoded_tile = decoded_tile.to(dtype=compute_dtype) + + # Initialize output buffers (compute_dtype) + if output is None: + output = torch.zeros( + (b, decoded_tile.shape[1], f_out, h_out, w_out), + device=decoded_tile.device, + dtype=compute_dtype, + ) + weights = torch.zeros( + (b, 1, f_out, h_out, w_out), + device=decoded_tile.device, + dtype=compute_dtype, + ) + + # Tile placement in output pixel space + out_h_start = v_start * sf + out_h_end = v_end * sf + out_w_start = h_start * sf + out_w_end = h_end * sf + + tile_out_h = out_h_end - out_h_start + tile_out_w = out_w_end - out_w_start + + # Linear feathering weights [B, 1, F, Ht, Wt] (compute_dtype) + tile_weights = torch.ones( + (b, 1, decoded_tile.shape[2], tile_out_h, tile_out_w), + device=decoded_tile.device, + dtype=compute_dtype, + ) + + overlap_out_h = overlap * sf + overlap_out_w = overlap * sf + + # Horizontal feathering: left/right overlaps + if overlap_out_w > 0: + if h > 0: + h_blend = torch.linspace( + 0, 1, steps=overlap_out_w, device=decoded_tile.device, dtype=compute_dtype + ) + tile_weights[:, :, :, :, :overlap_out_w] *= h_blend.view(1, 1, 1, 1, -1) + if h < horizontal_tiles - 1: + h_blend = torch.linspace( + 1, 0, steps=overlap_out_w, device=decoded_tile.device, dtype=compute_dtype + ) + tile_weights[:, :, :, :, -overlap_out_w:] *= h_blend.view(1, 1, 1, 1, -1) + + # Vertical feathering: top/bottom overlaps + if overlap_out_h > 0: + if v > 0: + v_blend = torch.linspace( + 0, 1, steps=overlap_out_h, device=decoded_tile.device, dtype=compute_dtype + ) + tile_weights[:, :, :, :overlap_out_h, :] *= v_blend.view(1, 1, 1, -1, 1) + if v < vertical_tiles - 1: + v_blend = torch.linspace( + 1, 0, steps=overlap_out_h, device=decoded_tile.device, dtype=compute_dtype + ) + tile_weights[:, :, :, -overlap_out_h:, :] *= v_blend.view(1, 1, 1, -1, 1) + + # Accumulate blended tile + output[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += decoded_tile * tile_weights + weights[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += tile_weights + + # Normalize, then clamp to [-1, 1] in compute_dtype to avoid color artifacts + output = output / (weights + 1e-8) + output = output.clamp(-1.0, 1.0) + output = output.to(dtype=self.vae.dtype) + + # Optional: drop the last tsf frames after last_frame_fix + if last_frame_fix: + output = output[:, :, :-tsf, :, :] + + if output_type in ("np", "pil"): + return self.video_processor.postprocess_video(output, output_type=output_type) + return output + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + prompt_segments: list[dict[str, Any]] | None = None, + height: int = 512, + width: int = 704, + num_frames: int = 161, + frame_rate: float = 25, + guidance_scale: float = 1.0, + guidance_rescale: float = 0.0, + num_inference_steps: int | None = 8, + sigmas: list[float, torch.Tensor] | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + seed: int | None = 0, + cond_image: "PIL.Image.Image" | torch.Tensor | None = None, + cond_strength: float = 0.5, + latents: torch.Tensor | None = None, + temporal_tile_size: int = 80, + temporal_overlap: int = 24, + temporal_overlap_cond_strength: float = 0.5, + adain_factor: float = 0.25, + guidance_latents: torch.Tensor | None = None, + guiding_strength: float = 1.0, + negative_index_latents: torch.Tensor | None = None, + negative_index_strength: float = 1.0, + skip_steps_sigma_threshold: float | None = 1, + decode_timestep: float | None = 0.05, + decode_noise_scale: float | None = 0.025, + decode_horizontal_tiles: int = 4, + decode_vertical_tiles: int = 4, + decode_overlap: int = 3, + output_type: str | None = "latent", # "latent" | "pt" | "np" | "pil" + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 128, + ): + r""" + Generate an image-to-video sequence via temporal sliding windows and multi-prompt scheduling. + + Args: + prompt (`str` or `list[str]`, *optional*): + Positive text prompt(s) per window. If a single string contains '|', parts are split by bars. + negative_prompt (`str` or `list[str]`, *optional*): + Negative prompt(s) to suppress undesired content. + prompt_segments (`list[dict]`, *optional*): + Segment mapping with {"start_window", "end_window", "text"} to override prompts per window. + height (`int`, defaults to `512`): + Output image height in pixels; must be divisible by 32. + width (`int`, defaults to `704`): + Output image width in pixels; must be divisible by 32. + num_frames (`int`, defaults to `161`): + Number of output frames (in decoded pixel space). + frame_rate (`float`, defaults to `25`): + Frames-per-second; used to normalize temporal coordinates in `video_coords`. + guidance_scale (`float`, defaults to `1.0`): + CFG scale; values > 1 enable classifier-free guidance. + guidance_rescale (`float`, defaults to `0.0`): + Optional rescale to mitigate overexposure under CFG (see `rescale_noise_cfg`). + num_inference_steps (`int`, *optional*, defaults to `8`): + Denoising steps per window. Ignored if `sigmas` is provided. + sigmas (`list[float]` or `torch.Tensor`, *optional*): + Explicit sigma schedule per window; if set, overrides `num_inference_steps`. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + Controls stochasticity; list accepted but first element is used (batch=1). + seed (`int`, *optional*, defaults to `0`): + If provided, seeds the shared generator for global latents and derives a window-local generator with + `seed + w_start` per temporal window. + cond_image (`PIL.Image.Image` or `torch.Tensor`, *optional*): + Conditioning image; fixes frame 0 via per-token mask when `cond_strength > 0`. + cond_strength (`float`, defaults to `0.5`): + Strength of first-frame hard conditioning (smaller cond_mask ⇒ stronger preservation). + latents (`torch.Tensor`, *optional*): + Initial latents [B, C_lat, F_lat, H_lat, W_lat]; if None, sampled with `randn_tensor`. + temporal_tile_size (`int`, defaults to `80`): + Temporal window size (in decoded frames); internally scaled by VAE temporal compression. + temporal_overlap (`int`, defaults to `24`): + Overlap between consecutive windows (in decoded frames); internally scaled by compression. + temporal_overlap_cond_strength (`float`, defaults to `0.5`): + Strength for injecting previous window tail latents at new window head. + adain_factor (`float`, defaults to `0.25`): + AdaIN normalization strength for cross-window consistency (0 disables). + guidance_latents (`torch.Tensor`, *optional*): + Reference latents injected at window head; length trimmed by overlap for subsequent windows. + guiding_strength (`float`, defaults to `1.0`): + Injection strength for `guidance_latents`. + negative_index_latents (`torch.Tensor`, *optional*): + A single-frame latent appended at window head for "negative index" semantics. + negative_index_strength (`float`, defaults to `1.0`): + Injection strength for `negative_index_latents`. + skip_steps_sigma_threshold (`float`, *optional*, defaults to `1`): + Skip steps whose sigma exceeds this threshold. + decode_timestep (`float`, *optional*, defaults to `0.05`): + Decode-time timestep (if VAE supports timestep_conditioning). + decode_noise_scale (`float`, *optional*, defaults to `0.025`): + Decode-time noise mix scale (if VAE supports timestep_conditioning). + decode_horizontal_tiles (`int`, defaults to `4`): + Number of horizontal tiles during VAE decoding. + decode_vertical_tiles (`int`, defaults to `4`): + Number of vertical tiles during VAE decoding. + decode_overlap (`int`, defaults to `3`): + Overlap (in latent pixels) between tiles during VAE decoding. + output_type (`str`, *optional*, defaults to `"latent"`): + The output format of the generated video. Choose between "latent", "pt", "np", or "pil". If "latent", + returns latents without decoding. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + Extra attention parameters forwarded to the transformer. + callback_on_step_end (`PipelineCallback` or `MultiPipelineCallbacks`, *optional*): + Per-step callback hook. + callback_on_step_end_tensor_inputs (`list[str]`, defaults to `["latents"]`): + Keys from locals() to pass into the callback. + max_sequence_length (`int`, defaults to `128`): + Tokenizer max length for prompt encoding. + + Examples: + + Returns: + [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated frames. The output format depends on + `output_type`: + - "latent"/"pt": `torch.Tensor` [B, C, F, H, W]; "latent" is in normalized latent space, "pt" is VAE + output space. + - "np": `np.ndarray` post-processed. + - "pil": `list[PIL.Image.Image]` list of PIL images. + + Shapes: + Latent sizes (when auto-generated): + - F_lat = (num_frames - 1) // vae_temporal_compression_ratio + 1 + - H_lat = height // vae_spatial_compression_ratio + - W_lat = width // vae_spatial_compression_ratio + + Notes: + - Seeding: when `seed` is provided, each temporal window uses a local generator seeded with `seed + + w_start`, while the shared generator is seeded once for global latents if no generator is passed; + otherwise the passed-in generator is reused. + - CFG: unified `noise_pred = uncond + w * (text - uncond)` with optional `guidance_rescale`. + - Memory: denoising performs full-frame predictions (no spatial tiling); decoding can be tiled to avoid + OOM. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Input validation: height/width must be divisible by 32 + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + # 1. Device & generator + device = self._execution_device + # Normalize generator input: accept list but use the first (batch_size=1) + if isinstance(generator, list): + generator = generator[0] + if seed is not None and generator is None: + generator = torch.Generator(device=device).manual_seed(seed) + + # 2. Optional i2v first frame conditioning: encode cond_image and inject at frame 0 via prepare_latents + cond_latents = None + if cond_image is not None and cond_strength > 0: + img = self.video_processor.preprocess(cond_image, height=height, width=width) + img = img.to(device=device, dtype=self.vae.dtype) + enc = self.vae.encode(img.unsqueeze(2)) # [B, C, 1, h, w] + cond_latents = enc.latent_dist.mode() if hasattr(enc, "latent_dist") else enc.latents + cond_latents = cond_latents.to(torch.float32) + cond_latents = self._normalize_latents( + cond_latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + + # 3. Global initial latents [B,C,F,H,W], optionally seeded/conditioned + latents, negative_index_latents, latent_num_frames, latent_height, latent_width = self.prepare_latents( + batch_size=1, + num_channels_latents=self.transformer.config.in_channels, + height=height, + width=width, + num_frames=num_frames, + device=device, + generator=generator, + dtype=torch.float32, + latents=latents, + cond_latents=cond_latents, + cond_strength=cond_strength, + negative_index_latents=negative_index_latents, + ) + if guidance_latents is not None: + guidance_latents = guidance_latents.to(device=device, dtype=torch.float32) + if latents.shape[2] != guidance_latents.shape[2]: + raise ValueError("The number of frames in `latents` and `guidance_latents` must be the same") + + # 4. Sliding windows in latent frames + tile_size_lat = max(1, temporal_tile_size // self.vae_temporal_compression_ratio) + overlap_lat = max(0, temporal_overlap // self.vae_temporal_compression_ratio) + windows = split_into_temporal_windows( + latent_num_frames, tile_size_lat, overlap_lat, self.vae_temporal_compression_ratio + ) + + # 5. Multi-prompt segments parsing + segment_texts = parse_prompt_segments(prompt, prompt_segments) + + out_latents = None + first_window_latents = None + + # 6. Process each temporal window + for w_idx, (w_start, w_end) in enumerate(windows): + if self.interrupt: + break + + # 6.1 Encode prompt embeddings per window segment + seg_index = min(w_idx, len(segment_texts) - 1) if segment_texts else 0 + pos_text = segment_texts[seg_index] if segment_texts else (prompt if isinstance(prompt, str) else "") + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=[pos_text], + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=1, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + max_sequence_length=max_sequence_length, + device=device, + dtype=None, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 6.2 Window-level timesteps reset: fresh sampling for each temporal window + if sigmas is not None: + s = torch.tensor(sigmas, dtype=torch.float32) if not isinstance(sigmas, torch.Tensor) else sigmas + self.scheduler.set_timesteps(sigmas=s, device=device) + self._num_timesteps = len(sigmas) + else: + self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=device) + self._num_timesteps = num_inference_steps + + # 6.3 Extract window latents [B,C,T,H,W] + window_latents = latents[:, :, w_start:w_end] + window_guidance_latents = guidance_latents[:, :, w_start:w_end] if guidance_latents is not None else None + window_T = window_latents.shape[2] + + # 6.4 Build per-window cond mask and inject previous tails / reference + window_cond_mask_5d = torch.ones( + (1, 1, window_T, latent_height, latent_width), device=device, dtype=torch.float32 + ) + self._current_tile_T = window_T + prev_overlap_len = 0 + # Inter-window tail latent injection (Extend) + if w_idx > 0 and overlap_lat > 0 and out_latents is not None: + k = min(overlap_lat, out_latents.shape[2]) + prev_tail = out_latents[:, :, -k:] + window_latents, window_cond_mask_5d, prev_overlap_len = inject_prev_tail_latents( + window_latents, + prev_tail, + window_cond_mask_5d, + overlap_lat, + temporal_overlap_cond_strength, + prev_overlap_len, + ) + # Reference/negative-index latent injection (append 1 frame at window head; controlled by negative_index_strength) + if window_guidance_latents is not None: + guiding_len = ( + window_guidance_latents.shape[2] if w_idx == 0 else window_guidance_latents.shape[2] - overlap_lat + ) + window_latents, window_cond_mask_5d, prev_overlap_len = inject_prev_tail_latents( + window_latents, + window_guidance_latents[:, :, -guiding_len:], + window_cond_mask_5d, + guiding_len, + guiding_strength, + prev_overlap_len, + ) + else: + guiding_len = 0 + window_latents, window_cond_mask_5d, prev_overlap_len = inject_prev_tail_latents( + window_latents, + negative_index_latents, + window_cond_mask_5d, + 1, + negative_index_strength, + prev_overlap_len, + ) + if w_idx == 0 and cond_image is not None and cond_strength > 0: + # First-frame I2V: smaller mask means stronger preservation of the original latent + window_cond_mask_5d[:, :, 0] = 1.0 - cond_strength + + # Update effective window latent sizes (consider injections on T/H/W) + w_B, w_C, w_T_eff, w_H_eff, w_W_eff = window_latents.shape + p = self.transformer_spatial_patch_size + pt = self.transformer_temporal_patch_size + + # 6.5 Pack full-window latents/masks once + # Seeding policy: derive a window-local generator to decouple RNG across windows + if seed is not None: + tile_seed = int(seed) + int(w_start) + local_gen = torch.Generator(device=device).manual_seed(tile_seed) + else: + local_gen = generator + # randn*mask + (1-mask)*latents implements hard-condition initialization + init_rand = randn_tensor(window_latents.shape, generator=local_gen, device=device, dtype=torch.float32) + mixed_latents = init_rand * window_cond_mask_5d + (1 - window_cond_mask_5d) * window_latents + window_latents_packed = self._pack_latents( + window_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + latents_packed = self._pack_latents( + mixed_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + cond_mask_tokens = self._pack_latents( + window_cond_mask_5d, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + if self.do_classifier_free_guidance: + cond_mask = torch.cat([cond_mask_tokens, cond_mask_tokens], dim=0) + else: + cond_mask = cond_mask_tokens + + # 6.6 Denoising loop per full window (no spatial tiling) + sigmas_current = self.scheduler.sigmas.to(device=latents_packed.device) + if sigmas_current.shape[0] >= 2: + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[:-1])): + if self.interrupt: + break + # Skip semantics: if sigma exceeds threshold, skip this step (do not call scheduler.step) + sigma_val = float(sigmas_current[i].item()) + if skip_steps_sigma_threshold is not None and float(skip_steps_sigma_threshold) > 0.0: + if sigma_val > float(skip_steps_sigma_threshold): + continue + + self._current_timestep = t + + # Model input (stack 2 copies under CFG) + latent_model_input = ( + torch.cat([latents_packed] * 2) if self.do_classifier_free_guidance else latents_packed + ) + # Broadcast timesteps, combine with per-token cond mask (I2V at window head) + timestep = t.expand(latent_model_input.shape[0]) + if cond_mask is not None: + # Broadcast timestep to per-token mask under CFG: [B] -> [B, S, 1] + timestep = timestep[:, None, None] * cond_mask + + # Micro-conditions: only provide video_coords (num_frames/height/width set to 1) + rope_interpolation_scale = ( + self.vae_temporal_compression_ratio, + self.vae_spatial_compression_ratio, + self.vae_spatial_compression_ratio, + ) + # Inpainting pre-blend (ComfyUI parity: KSamplerX0Inpaint:400) + if cond_mask_tokens is not None: + latents_packed = latents_packed * cond_mask_tokens + window_latents_packed * ( + 1.0 - cond_mask_tokens + ) + + # Negative-index/overlap lengths (for segmenting time coordinates; RoPE-compatible) + k_negative_count = ( + 1 if (negative_index_latents is not None and float(negative_index_strength) > 0.0) else 0 + ) + k_overlap_count = overlap_lat if (w_idx > 0 and overlap_lat > 0) else 0 + video_coords = build_video_coords_for_window( + latents=window_latents, + overlap_len=int(k_overlap_count), + guiding_len=int(guiding_len), + negative_len=int(k_negative_count), + rope_interpolation_scale=rope_interpolation_scale, + frame_rate=frame_rate, + ) + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input.to(dtype=self.transformer.dtype), + encoder_hidden_states=prompt_embeds, + timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + num_frames=1, + height=1, + width=1, + rope_interpolation_scale=rope_interpolation_scale, + video_coords=video_coords, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + # Unified CFG + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + if self.guidance_rescale > 0: + noise_pred = rescale_noise_cfg( + noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale + ) + + # Use global timestep for scheduling, but apply suppressive blending with hard-condition tokens (e.g., first frame) after step to avoid brightness/flicker due to time misalignment + latents_packed = self.scheduler.step( + noise_pred, t, latents_packed, generator=local_gen, return_dict=False + )[0] + # Inpainting post-blend (ComfyUI parity: restore hard-conditioned regions after update) + if cond_mask_tokens is not None: + latents_packed = latents_packed * cond_mask_tokens + window_latents_packed * ( + 1.0 - cond_mask_tokens + ) + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents_packed = callback_outputs.pop("latents", latents_packed) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + if XLA_AVAILABLE: + xm.mark_step() + else: + # Not enough sigmas to perform a valid step; skip this window safely. + pass + + # 6.7 Unpack back to [B,C,T,H,W] once + window_out = self._unpack_latents( + latents_packed, + w_T_eff, + w_H_eff, + w_W_eff, + p, + pt, + ) + if prev_overlap_len > 0: + window_out = window_out[:, :, :-prev_overlap_len] + + # 6.8 Overlap handling and fusion + if out_latents is None: + # First window: keep all latent frames and cache as AdaIN reference + out_latents = window_out + first_window_latents = out_latents + else: + window_out = window_out[:, :, 1:] # Drop the first frame of the new window + if adain_factor > 0 and first_window_latents is not None: + window_out = adain_normalize_latents(window_out, first_window_latents, adain_factor) + overlap_len = max(overlap_lat - 1, 1) + prev_tail_chunk = out_latents[:, :, -window_out.shape[2] :] + fused = linear_overlap_fuse(prev_tail_chunk, window_out, overlap_len) + out_latents = torch.cat([out_latents[:, :, : -window_out.shape[2]], fused], dim=2) + + # 7. Decode or return latent + if output_type == "latent": + video = out_latents + else: + # Decode via tiling to avoid OOM from full-frame decoding; latents are already de-normalized, so keep auto_denormalize disabled + video = self.vae_decode_tiled( + out_latents, + decode_timestep=decode_timestep, + decode_noise_scale=decode_noise_scale, + horizontal_tiles=int(decode_horizontal_tiles), + vertical_tiles=int(decode_vertical_tiles), + overlap=int(decode_overlap), + generator=generator, + output_type=output_type, # Keep type consistent; postprocess is applied afterwards + ) + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return LTXPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/pipeline_ltx_image2video.py new file mode 100644 index 0000000000000000000000000000000000000000..497f505c4dd828cdfcdcd2b98c75448b415c1a48 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -0,0 +1,949 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin +from ...models.autoencoders import AutoencoderKLLTXVideo +from ...models.transformers import LTXVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import LTXPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTXImageToVideoPipeline + >>> from diffusers.utils import export_to_video, load_image + + >>> pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> image = load_image( + ... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png" + ... ) + >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background. Flames engulf the structure, with smoke billowing into the air. Firefighters in protective gear rush to the scene, a fire truck labeled '38' visible behind them. The girl's neutral expression contrasts sharply with the chaos of the fire, creating a poignant and emotionally charged scene." + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> video = pipe( + ... image=image, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=704, + ... height=480, + ... num_frames=161, + ... num_inference_steps=50, + ... ).frames[0] + >>> export_to_video(video, "output.mp4", fps=24) + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin): + r""" + Pipeline for image-to-video generation. + + Reference: https://github.com/Lightricks/LTX-Video + + Args: + transformer ([`LTXVideoTransformer3DModel`]): + Conditional Transformer architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLLTXVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTXVideo, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: LTXVideoTransformer3DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128 + ) + + self.default_height = 512 + self.default_width = 704 + self.default_frames = 121 + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 128, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128 + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + max_sequence_length: int = 128, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. + batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + def prepare_latents( + self, + image: torch.Tensor | None = None, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 704, + num_frames: int = 161, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + height = height // self.vae_spatial_compression_ratio + width = width // self.vae_spatial_compression_ratio + num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + mask_shape = (batch_size, 1, num_frames, height, width) + + if latents is not None: + conditioning_mask = latents.new_zeros(mask_shape) + conditioning_mask[:, :, 0] = 1.0 + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ).squeeze(-1) + if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape + (num_channels_latents,)}." + ) + return latents.to(device=device, dtype=dtype), conditioning_mask + + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0).unsqueeze(2)), generator[i]) + for i in range(batch_size) + ] + else: + init_latents = [ + retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2)), generator) for img in image + ] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std) + init_latents = init_latents.repeat(1, 1, num_frames, 1, 1) + conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype) + conditioning_mask[:, :, 0] = 1.0 + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask) + + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ).squeeze(-1) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + return latents, conditioning_mask + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput = None, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + height: int = 512, + width: int = 704, + num_frames: int = 161, + frame_rate: int = 25, + num_inference_steps: int = 50, + timesteps: list[int] = None, + guidance_scale: float = 3, + guidance_rescale: float = 0.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + decode_timestep: float | list[float] = 0.0, + decode_noise_scale: float | list[float] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 128, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, defaults to `512`): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, defaults to `704`): + The width in pixels of the generated image. This is set to 848 by default for the best results. + num_frames (`int`, defaults to `161`): + The number of video frames to generate + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, defaults to `3 `): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `None`): + The interpolation factor between random noise and denoised latents at the decode timestep. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `128 `): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare latent variables + if latents is None: + image = self.video_processor.preprocess(image, height=height, width=width) + image = image.to(device=device, dtype=prompt_embeds.dtype) + + num_channels_latents = self.transformer.config.in_channels + latents, conditioning_mask = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + if self.do_classifier_free_guidance: + conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) + + # 5. Prepare timesteps + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + video_sequence_length = latent_num_frames * latent_height * latent_width + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + mu = calculate_shift( + video_sequence_length, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + timestep_device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Prepare micro-conditions + rope_interpolation_scale = ( + self.vae_temporal_compression_ratio / frame_rate, + self.vae_spatial_compression_ratio, + self.vae_spatial_compression_ratio, + ) + + # 7. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) + + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + rope_interpolation_scale=rope_interpolation_scale, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + timestep, _ = timestep.chunk(2) + + if self.guidance_rescale > 0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg( + noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + noise_pred = self._unpack_latents( + noise_pred, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + + noise_pred = noise_pred[:, :, 1:] + noise_latents = latents[:, :, 1:] + pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0] + + latents = torch.cat([latents[:, :, :1], pred_latents], dim=2) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + video = latents + else: + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return LTXPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py new file mode 100644 index 0000000000000000000000000000000000000000..17d4e1d8fc57a178ab40cffee94a60ae495e4298 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py @@ -0,0 +1,339 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLLTXVideo +from ...utils import deprecate, get_logger +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .modeling_latent_upsampler import LTXLatentUpsamplerModel +from .pipeline_output import LTXPipelineOutput + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class LTXLatentUpsamplePipeline(DiffusionPipeline): + model_cpu_offload_seq = "" + + def __init__( + self, + vae: AutoencoderKLLTXVideo, + latent_upsampler: LTXLatentUpsamplerModel, + ) -> None: + super().__init__() + + self.register_modules(vae=vae, latent_upsampler=latent_upsampler) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + + def prepare_latents( + self, + video: torch.Tensor | None = None, + batch_size: int = 1, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + video = video.to(device=device, dtype=self.vae.dtype) + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + init_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + else: + init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std) + return init_latents + + def adain_filter_latent(self, latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0): + """ + Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on statistics from a reference latent + tensor. + + Args: + latent (`torch.Tensor`): + Input latents to normalize + reference_latents (`torch.Tensor`): + The reference latents providing style statistics. + factor (`float`): + Blending factor between original and transformed latent. Range: -10.0 to 10.0, Default: 1.0 + + Returns: + torch.Tensor: The transformed latent tensor + """ + result = latents.clone() + + for i in range(latents.size(0)): + for c in range(latents.size(1)): + r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None) # index by original dim order + i_sd, i_mean = torch.std_mean(result[i, c], dim=None) + + result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean + + result = torch.lerp(latents, result, factor) + return result + + def tone_map_latents(self, latents: torch.Tensor, compression: float) -> torch.Tensor: + """ + Applies a non-linear tone-mapping function to latent values to reduce their dynamic range in a perceptually + smooth way using a sigmoid-based compression. + + This is useful for regularizing high-variance latents or for conditioning outputs during generation, especially + when controlling dynamic behavior with a `compression` factor. + + Args: + latents : torch.Tensor + Input latent tensor with arbitrary shape. Expected to be roughly in [-1, 1] or [0, 1] range. + compression : float + Compression strength in the range [0, 1]. + - 0.0: No tone-mapping (identity transform) + - 1.0: Full compression effect + + Returns: + torch.Tensor + The tone-mapped latent tensor of the same shape as input. + """ + # Remap [0-1] to [0-0.75] and apply sigmoid compression in one shot + scale_factor = compression * 0.75 + abs_latents = torch.abs(latents) + + # Sigmoid compression: sigmoid shifts large values toward 0.2, small values stay ~1.0 + # When scale_factor=0, sigmoid term vanishes, when scale_factor=0.75, full effect + sigmoid_term = torch.sigmoid(4.0 * scale_factor * (abs_latents - 1.0)) + scales = 1.0 - 0.8 * scale_factor * sigmoid_term + + filtered = latents * scales + return filtered + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + def check_inputs(self, video, height, width, latents, tone_map_compression_ratio): + if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if video is not None and latents is not None: + raise ValueError("Only one of `video` or `latents` can be provided.") + if video is None and latents is None: + raise ValueError("One of `video` or `latents` has to be provided.") + + if not (0 <= tone_map_compression_ratio <= 1): + raise ValueError("`tone_map_compression_ratio` must be in the range [0, 1]") + + @torch.no_grad() + def __call__( + self, + video: list[PipelineImageInput] | None = None, + height: int = 512, + width: int = 704, + latents: torch.Tensor | None = None, + decode_timestep: float | list[float] = 0.0, + decode_noise_scale: float | list[float] | None = None, + adain_factor: float = 0.0, + tone_map_compression_ratio: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + ): + self.check_inputs( + video=video, + height=height, + width=width, + latents=latents, + tone_map_compression_ratio=tone_map_compression_ratio, + ) + + if video is not None: + # Batched video input is not yet tested/supported. TODO: take a look later + batch_size = 1 + else: + batch_size = latents.shape[0] + device = self._execution_device + + if video is not None: + num_frames = len(video) + if num_frames % self.vae_temporal_compression_ratio != 1: + num_frames = ( + num_frames // self.vae_temporal_compression_ratio * self.vae_temporal_compression_ratio + 1 + ) + video = video[:num_frames] + logger.warning( + f"Video length expected to be of the form `k * {self.vae_temporal_compression_ratio} + 1` but is {len(video)}. Truncating to {num_frames} frames." + ) + video = self.video_processor.preprocess_video(video, height=height, width=width) + video = video.to(device=device, dtype=torch.float32) + + latents = self.prepare_latents( + video=video, + batch_size=batch_size, + dtype=torch.float32, + device=device, + generator=generator, + latents=latents, + ) + + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = latents.to(self.latent_upsampler.dtype) + latents_upsampled = self.latent_upsampler(latents) + + if adain_factor > 0.0: + latents = self.adain_filter_latent(latents_upsampled, latents, adain_factor) + else: + latents = latents_upsampled + + if tone_map_compression_ratio > 0.0: + latents = self.tone_map_latents(latents, tone_map_compression_ratio) + + if output_type == "latent": + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + video = latents + else: + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return LTXPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..f5cb34aa508dc8c6e570f310cbd7615d4cd19154 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class LTXPipelineOutput(BaseOutput): + r""" + Output class for LTX pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d6a408d5c546f9ca1ae2f4d00ec00b4a136d6800 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/__init__.py @@ -0,0 +1,60 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["connectors"] = ["LTX2TextConnectors"] + _import_structure["latent_upsampler"] = ["LTX2LatentUpsamplerModel"] + _import_structure["pipeline_ltx2"] = ["LTX2Pipeline"] + _import_structure["pipeline_ltx2_condition"] = ["LTX2ConditionPipeline"] + _import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"] + _import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"] + _import_structure["vocoder"] = ["LTX2Vocoder"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .connectors import LTX2TextConnectors + from .latent_upsampler import LTX2LatentUpsamplerModel + from .pipeline_ltx2 import LTX2Pipeline + from .pipeline_ltx2_condition import LTX2ConditionPipeline + from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline + from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline + from .vocoder import LTX2Vocoder + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/connectors.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/connectors.py new file mode 100644 index 0000000000000000000000000000000000000000..4b2a81a9dc2c7ba7fcb7d3205b4e91a2c53c49ec --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/connectors.py @@ -0,0 +1,324 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...models.attention import FeedForward +from ...models.modeling_utils import ModelMixin +from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor + + +class LTX2RotaryPosEmbed1d(nn.Module): + """ + 1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors. + """ + + def __init__( + self, + dim: int, + base_seq_len: int = 4096, + theta: float = 10000.0, + double_precision: bool = True, + rope_type: str = "interleaved", + num_attention_heads: int = 32, + ): + super().__init__() + if rope_type not in ["interleaved", "split"]: + raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.") + + self.dim = dim + self.base_seq_len = base_seq_len + self.theta = theta + self.double_precision = double_precision + self.rope_type = rope_type + self.num_attention_heads = num_attention_heads + + def forward( + self, + batch_size: int, + pos: int, + device: str | torch.device, + ) -> tuple[torch.Tensor, torch.Tensor]: + # 1. Get 1D position ids + grid_1d = torch.arange(pos, dtype=torch.float32, device=device) + # Get fractional indices relative to self.base_seq_len + grid_1d = grid_1d / self.base_seq_len + grid = grid_1d.unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len] + + # 2. Calculate 1D RoPE frequencies + num_rope_elems = 2 # 1 (because 1D) * 2 (for cos, sin) = 2 + freqs_dtype = torch.float64 if self.double_precision else torch.float32 + pow_indices = torch.pow( + self.theta, + torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device), + ) + freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32) + + # 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape + # (self.dim // 2,). + freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2] + + # 4. Get real, interleaved (cos, sin) frequencies, padded to self.dim + if self.rope_type == "interleaved": + cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) + sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) + + if self.dim % num_rope_elems != 0: + cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems]) + cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) + sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) + + elif self.rope_type == "split": + expected_freqs = self.dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq = freqs.cos() + sin_freq = freqs.sin() + + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) + + cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) + + # Reshape freqs to be compatible with multi-head attention + b = cos_freq.shape[0] + t = cos_freq.shape[1] + + cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1) + sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1) + + cos_freqs = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2) + sin_freqs = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2) + + return cos_freqs, sin_freqs + + +class LTX2TransformerBlock1d(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + activation_fn: str = "gelu-approximate", + eps: float = 1e-6, + rope_type: str = "interleaved", + ): + super().__init__() + + self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) + self.attn1 = LTX2Attention( + query_dim=dim, + heads=num_attention_heads, + kv_heads=num_attention_heads, + dim_head=attention_head_dim, + processor=LTX2AudioVideoAttnProcessor(), + rope_type=rope_type, + ) + + self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) + self.ff = FeedForward(dim, activation_fn=activation_fn) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + rotary_emb: torch.Tensor | None = None, + ) -> torch.Tensor: + norm_hidden_states = self.norm1(hidden_states) + attn_hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, query_rotary_emb=rotary_emb) + hidden_states = hidden_states + attn_hidden_states + + norm_hidden_states = self.norm2(hidden_states) + ff_hidden_states = self.ff(norm_hidden_states) + hidden_states = hidden_states + ff_hidden_states + + return hidden_states + + +class LTX2ConnectorTransformer1d(nn.Module): + """ + A 1D sequence transformer for modalities such as text. + + In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + num_attention_heads: int = 30, + attention_head_dim: int = 128, + num_layers: int = 2, + num_learnable_registers: int | None = 128, + rope_base_seq_len: int = 4096, + rope_theta: float = 10000.0, + rope_double_precision: bool = True, + eps: float = 1e-6, + causal_temporal_positioning: bool = False, + rope_type: str = "interleaved", + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self.causal_temporal_positioning = causal_temporal_positioning + + self.num_learnable_registers = num_learnable_registers + self.learnable_registers = None + if num_learnable_registers is not None: + init_registers = torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0 + self.learnable_registers = torch.nn.Parameter(init_registers) + + self.rope = LTX2RotaryPosEmbed1d( + self.inner_dim, + base_seq_len=rope_base_seq_len, + theta=rope_theta, + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=num_attention_heads, + ) + + self.transformer_blocks = torch.nn.ModuleList( + [ + LTX2TransformerBlock1d( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + rope_type=rope_type, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=False) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + attn_mask_binarize_threshold: float = -9000.0, + ) -> tuple[torch.Tensor, torch.Tensor]: + # hidden_states shape: [batch_size, seq_len, hidden_dim] + # attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len] + batch_size, seq_len, _ = hidden_states.shape + + # 1. Replace padding with learned registers, if using + if self.learnable_registers is not None: + if seq_len % self.num_learnable_registers != 0: + raise ValueError( + f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number" + f" of learnable registers {self.num_learnable_registers}" + ) + + num_register_repeats = seq_len // self.num_learnable_registers + registers = torch.tile(self.learnable_registers, (num_register_repeats, 1)) # [seq_len, inner_dim] + + binary_attn_mask = (attention_mask >= attn_mask_binarize_threshold).int() + if binary_attn_mask.ndim == 4: + binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L] + + hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)] + valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded] + pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens] + padded_hidden_states = [ + F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths) + ] + padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0) # [B, L, D] + + flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1) # [B, L, 1] + hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers + + # Overwrite attention_mask with an all-zeros mask if using registers. + attention_mask = torch.zeros_like(attention_mask) + + # 2. Calculate 1D RoPE positional embeddings + rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device) + + # 3. Run 1D transformer blocks + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(block, hidden_states, attention_mask, rotary_emb) + else: + hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb) + + hidden_states = self.norm_out(hidden_states) + + return hidden_states, attention_mask + + +class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin): + """ + Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and audio + streams. + """ + + @register_to_config + def __init__( + self, + caption_channels: int, + text_proj_in_factor: int, + video_connector_num_attention_heads: int, + video_connector_attention_head_dim: int, + video_connector_num_layers: int, + video_connector_num_learnable_registers: int | None, + audio_connector_num_attention_heads: int, + audio_connector_attention_head_dim: int, + audio_connector_num_layers: int, + audio_connector_num_learnable_registers: int | None, + connector_rope_base_seq_len: int, + rope_theta: float, + rope_double_precision: bool, + causal_temporal_positioning: bool, + rope_type: str = "interleaved", + ): + super().__init__() + self.text_proj_in = nn.Linear(caption_channels * text_proj_in_factor, caption_channels, bias=False) + self.video_connector = LTX2ConnectorTransformer1d( + num_attention_heads=video_connector_num_attention_heads, + attention_head_dim=video_connector_attention_head_dim, + num_layers=video_connector_num_layers, + num_learnable_registers=video_connector_num_learnable_registers, + rope_base_seq_len=connector_rope_base_seq_len, + rope_theta=rope_theta, + rope_double_precision=rope_double_precision, + causal_temporal_positioning=causal_temporal_positioning, + rope_type=rope_type, + ) + self.audio_connector = LTX2ConnectorTransformer1d( + num_attention_heads=audio_connector_num_attention_heads, + attention_head_dim=audio_connector_attention_head_dim, + num_layers=audio_connector_num_layers, + num_learnable_registers=audio_connector_num_learnable_registers, + rope_base_seq_len=connector_rope_base_seq_len, + rope_theta=rope_theta, + rope_double_precision=rope_double_precision, + causal_temporal_positioning=causal_temporal_positioning, + rope_type=rope_type, + ) + + def forward( + self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, additive_mask: bool = False + ): + # Convert to additive attention mask, if necessary + if not additive_mask: + text_dtype = text_encoder_hidden_states.dtype + attention_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + attention_mask = attention_mask.to(text_dtype) * torch.finfo(text_dtype).max + + text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states) + + video_text_embedding, new_attn_mask = self.video_connector(text_encoder_hidden_states, attention_mask) + + attn_mask = (new_attn_mask < 1e-6).to(torch.int64) + attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1) + video_text_embedding = video_text_embedding * attn_mask + new_attn_mask = attn_mask.squeeze(-1) + + audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, attention_mask) + + return video_text_embedding, audio_text_embedding, new_attn_mask diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/export_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/export_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f0287506b8dbda3d59a61fae5e581fbf9dd763a2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/export_utils.py @@ -0,0 +1,191 @@ +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Iterator +from fractions import Fraction +from itertools import chain + +import numpy as np +import PIL.Image +import torch +from tqdm import tqdm + +from ...utils import get_logger, is_av_available + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +_CAN_USE_AV = is_av_available() +if _CAN_USE_AV: + import av +else: + raise ImportError( + "PyAV is required to use LTX 2.0 video export utilities. You can install it with `pip install av`" + ) + + +def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream: + """ + Prepare the audio stream for writing. + """ + audio_stream = container.add_stream("aac", rate=audio_sample_rate) + audio_stream.codec_context.sample_rate = audio_sample_rate + audio_stream.codec_context.layout = "stereo" + audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate) + return audio_stream + + +def _resample_audio( + container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame +) -> None: + cc = audio_stream.codec_context + + # Use the encoder's format/layout/rate as the *target* + target_format = cc.format or "fltp" # AAC → usually fltp + target_layout = cc.layout or "stereo" + target_rate = cc.sample_rate or frame_in.sample_rate + + audio_resampler = av.audio.resampler.AudioResampler( + format=target_format, + layout=target_layout, + rate=target_rate, + ) + + audio_next_pts = 0 + for rframe in audio_resampler.resample(frame_in): + if rframe.pts is None: + rframe.pts = audio_next_pts + audio_next_pts += rframe.samples + rframe.sample_rate = frame_in.sample_rate + container.mux(audio_stream.encode(rframe)) + + # flush audio encoder + for packet in audio_stream.encode(): + container.mux(packet) + + +def _write_audio( + container: av.container.Container, + audio_stream: av.audio.AudioStream, + samples: torch.Tensor, + audio_sample_rate: int, +) -> None: + if samples.ndim == 1: + samples = samples[:, None] + + if samples.shape[1] != 2 and samples.shape[0] == 2: + samples = samples.T + + if samples.shape[1] != 2: + raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.") + + # Convert to int16 packed for ingestion; resampler converts to encoder fmt. + if samples.dtype != torch.int16: + samples = torch.clip(samples, -1.0, 1.0) + samples = (samples * 32767.0).to(torch.int16) + + frame_in = av.AudioFrame.from_ndarray( + samples.contiguous().reshape(1, -1).cpu().numpy(), + format="s16", + layout="stereo", + ) + frame_in.sample_rate = audio_sample_rate + + _resample_audio(container, audio_stream, frame_in) + + +def encode_video( + video: list[PIL.Image.Image] | np.ndarray | torch.Tensor | Iterator[torch.Tensor], + fps: int, + audio: torch.Tensor, + audio_sample_rate: int, + output_path: str, + video_chunks_number: int = 1, +) -> None: + """ + Encodes a video with audio using the PyAV library. Based on code from the original LTX-2 repo: + https://github.com/Lightricks/LTX-2/blob/4f410820b198e05074a1e92de793e3b59e9ab5a0/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py#L182 + + Args: + video (`List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor`): + A video tensor of shape [frames, height, width, channels] with integer pixel values in [0, 255]. If the + input is a `np.ndarray`, it is expected to be a float array with values in [0, 1] (which is what pipelines + usually return with `output_type="np"`). + fps (`int`) + The frames per second (FPS) of the encoded video. + audio (`torch.Tensor`, *optional*): + An audio waveform of shape [audio_channels, samples]. + audio_sample_rate: (`int`, *optional*): + The sampling rate of the audio waveform. For LTX 2, this is typically 24000 (24 kHz). + output_path (`str`): + The path to save the encoded video to. + video_chunks_number (`int`, *optional*, defaults to `1`): + The number of chunks to split the video into for encoding. Each chunk will be encoded separately. The + number of chunks to use often depends on the tiling config for the video VAE. + """ + if isinstance(video, list) and isinstance(video[0], PIL.Image.Image): + # Pipeline output_type="pil"; assumes each image is in "RGB" mode + video_frames = [np.array(frame) for frame in video] + video = np.stack(video_frames, axis=0) + video = torch.from_numpy(video) + elif isinstance(video, np.ndarray): + # Pipeline output_type="np" + is_denormalized = np.logical_and(np.zeros_like(video) <= video, video <= np.ones_like(video)) + if np.all(is_denormalized): + video = (video * 255).round().astype("uint8") + else: + logger.warning( + "Supplied `numpy.ndarray` does not have values in [0, 1]. The values will be assumed to be pixel " + "values in [0, ..., 255] and will be used as is." + ) + video = torch.from_numpy(video) + + if isinstance(video, torch.Tensor): + # Split into video_chunks_number along the frame dimension + video = torch.tensor_split(video, video_chunks_number, dim=0) + video = iter(video) + + first_chunk = next(video) + + _, height, width, _ = first_chunk.shape + + container = av.open(output_path, mode="w") + stream = container.add_stream("libx264", rate=int(fps)) + stream.width = width + stream.height = height + stream.pix_fmt = "yuv420p" + + if audio is not None: + if audio_sample_rate is None: + raise ValueError("audio_sample_rate is required when audio is provided") + + audio_stream = _prepare_audio_stream(container, audio_sample_rate) + + for video_chunk in tqdm(chain([first_chunk], video), total=video_chunks_number, desc="Encoding video chunks"): + video_chunk_cpu = video_chunk.to("cpu").numpy() + for frame_array in video_chunk_cpu: + frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24") + for packet in stream.encode(frame): + container.mux(packet) + + # Flush encoder + for packet in stream.encode(): + container.mux(packet) + + if audio is not None: + _write_audio(container, audio_stream, audio, audio_sample_rate) + + container.close() diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/latent_upsampler.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/latent_upsampler.py new file mode 100644 index 0000000000000000000000000000000000000000..f6c589a70ab6a3e1c22033ce86f6450447e5ec2f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/latent_upsampler.py @@ -0,0 +1,284 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin + + +RATIONAL_RESAMPLER_SCALE_MAPPING = { + 0.75: (3, 4), + 1.5: (3, 2), + 2.0: (2, 1), + 4.0: (4, 1), +} + + +# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.ResBlock +class ResBlock(torch.nn.Module): + def __init__(self, channels: int, mid_channels: int | None = None, dims: int = 3): + super().__init__() + if mid_channels is None: + mid_channels = channels + + Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = torch.nn.GroupNorm(32, mid_channels) + self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = torch.nn.GroupNorm(32, channels) + self.activation = torch.nn.SiLU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + hidden_states = self.conv1(hidden_states) + hidden_states = self.norm1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.conv2(hidden_states) + hidden_states = self.norm2(hidden_states) + hidden_states = self.activation(hidden_states + residual) + return hidden_states + + +# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.PixelShuffleND +class PixelShuffleND(torch.nn.Module): + def __init__(self, dims, upscale_factors=(2, 2, 2)): + super().__init__() + + self.dims = dims + self.upscale_factors = upscale_factors + + if dims not in [1, 2, 3]: + raise ValueError("dims must be 1, 2, or 3") + + def forward(self, x): + if self.dims == 3: + # spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3) + return ( + x.unflatten(1, (-1, *self.upscale_factors[:3])) + .permute(0, 1, 5, 2, 6, 3, 7, 4) + .flatten(6, 7) + .flatten(4, 5) + .flatten(2, 3) + ) + elif self.dims == 2: + # spatial: b (c p1 p2) h w -> b c (h p1) (w p2) + return ( + x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3) + ) + elif self.dims == 1: + # temporal: b (c p1) f h w -> b c (f p1) h w + return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3) + + +class BlurDownsample(torch.nn.Module): + """ + Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. Applies only on H,W. + Works for dims=2 or dims=3 (per-frame). + """ + + def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None: + super().__init__() + + if dims not in (2, 3): + raise ValueError(f"`dims` must be either 2 or 3 but is {dims}") + if kernel_size < 3 or kernel_size % 2 != 1: + raise ValueError(f"`kernel_size` must be an odd number >= 3 but is {kernel_size}") + + self.dims = dims + self.stride = stride + self.kernel_size = kernel_size + + # 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from + # the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and + # provides a smooth approximation of a Gaussian filter (often called a "binomial filter"). + # The 2D kernel is constructed as the outer product and normalized. + k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)]) + k2d = k[:, None] @ k[None, :] + k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size) + self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.stride == 1: + return x + + if self.dims == 2: + c = x.shape[1] + weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise + x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c) + else: + # dims == 3: apply per-frame on H,W + b, c, f, _, _ = x.shape + x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W] + + weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise + x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c) + + h2, w2 = x.shape[-2:] + x = x.unflatten(0, (b, f)).reshape(b, -1, f, h2, w2) # [B * F, C, H, W] --> [B, C, F, H, W] + return x + + +class SpatialRationalResampler(torch.nn.Module): + """ + Scales by the spatial size of the input by a rational number `scale`. For example, `scale = 0.75` will downsample + by a factor of 3 / 4, while `scale = 1.5` will upsample by a factor of 3 / 2. This works by first upsampling the + input by the (integer) numerator of `scale`, and then performing a blur + stride anti-aliased downsample by the + (integer) denominator. + """ + + def __init__(self, mid_channels: int = 1024, scale: float = 2.0): + super().__init__() + self.scale = float(scale) + num_denom = RATIONAL_RESAMPLER_SCALE_MAPPING.get(scale, None) + if num_denom is None: + raise ValueError( + f"The supplied `scale` {scale} is not supported; supported scales are {list(RATIONAL_RESAMPLER_SCALE_MAPPING.keys())}" + ) + self.num, self.den = num_denom + + self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1) + self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num)) + self.blur_down = BlurDownsample(dims=2, stride=self.den) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Expected x shape: [B * F, C, H, W] + # b, _, f, h, w = x.shape + # x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W] + x = self.conv(x) + x = self.pixel_shuffle(x) + x = self.blur_down(x) + # x = x.unflatten(0, (b, f)).reshape(b, -1, f, h, w) # [B * F, C, H, W] --> [B, C, F, H, W] + return x + + +class LTX2LatentUpsamplerModel(ModelMixin, ConfigMixin): + """ + Model to spatially upsample VAE latents. + + Args: + in_channels (`int`, defaults to `128`): + Number of channels in the input latent + mid_channels (`int`, defaults to `512`): + Number of channels in the middle layers + num_blocks_per_stage (`int`, defaults to `4`): + Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`, defaults to `3`): + Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`, defaults to `True`): + Whether to spatially upsample the latent + temporal_upsample (`bool`, defaults to `False`): + Whether to temporally upsample the latent + """ + + @register_to_config + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 1024, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + rational_spatial_scale: float | None = 2.0, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + + ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = torch.nn.GroupNorm(32, mid_channels) + self.initial_activation = torch.nn.SiLU() + + self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]) + + if spatial_upsample and temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + if rational_spatial_scale is not None: + self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=rational_spatial_scale) + else: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError("Either spatial_upsample or temporal_upsample must be True") + + self.post_upsample_res_blocks = torch.nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + if self.dims == 2: + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.initial_conv(hidden_states) + hidden_states = self.initial_norm(hidden_states) + hidden_states = self.initial_activation(hidden_states) + + for block in self.res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.upsampler(hidden_states) + + for block in self.post_upsample_res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.final_conv(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + else: + hidden_states = self.initial_conv(hidden_states) + hidden_states = self.initial_norm(hidden_states) + hidden_states = self.initial_activation(hidden_states) + + for block in self.res_blocks: + hidden_states = block(hidden_states) + + if self.temporal_upsample: + hidden_states = self.upsampler(hidden_states) + hidden_states = hidden_states[:, :, 1:, :, :] + else: + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.upsampler(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + + for block in self.post_upsample_res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.final_conv(hidden_states) + + return hidden_states diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/pipeline_ltx2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/pipeline_ltx2.py new file mode 100644 index 0000000000000000000000000000000000000000..037840360137ec9b360d516c42b62a413a78146d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -0,0 +1,1226 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin +from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video +from ...models.transformers import LTX2VideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .connectors import LTX2TextConnectors +from .pipeline_output import LTX2PipelineOutput +from .vocoder import LTX2Vocoder + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTX2Pipeline + >>> from diffusers.pipelines.ltx2.export_utils import encode_video + + >>> pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage" + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> frame_rate = 24.0 + >>> video, audio = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=768, + ... height=512, + ... num_frames=121, + ... frame_rate=frame_rate, + ... num_inference_steps=40, + ... guidance_scale=4.0, + ... output_type="np", + ... return_dict=False, + ... ) + + >>> encode_video( + ... video[0], + ... fps=frame_rate, + ... audio=audio[0].float().cpu(), + ... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000 + ... output_path="video.mp4", + ... ) + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): + r""" + Pipeline for text-to-video generation. + + Reference: https://github.com/Lightricks/LTX-Video + + Args: + transformer ([`LTXVideoTransformer3DModel`]): + Conditional Transformer architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLLTXVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + connectors ([`LTX2TextConnectors`]): + Text connector stack used to adapt text encoder hidden states for the video and audio branches. + """ + + model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTX2Video, + audio_vae: AutoencoderKLLTX2Audio, + text_encoder: Gemma3ForConditionalGeneration, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + connectors: LTX2TextConnectors, + transformer: LTX2VideoTransformer3DModel, + vocoder: LTX2Vocoder, + ): + super().__init__() + + self.register_modules( + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + connectors=connectors, + transformer=transformer, + vocoder=vocoder, + scheduler=scheduler, + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + # TODO: check whether the MEL compression ratio logic here is corrct + self.audio_vae_mel_compression_ratio = ( + self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.audio_vae_temporal_compression_ratio = ( + self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1 + ) + + self.audio_sampling_rate = ( + self.audio_vae.config.sample_rate if getattr(self, "audio_vae", None) is not None else 16000 + ) + self.audio_hop_length = ( + self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 + ) + + @staticmethod + def _pack_text_embeds( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device: str | torch.device, + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, + ) -> torch.Tensor: + """ + Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and + per-layer in a masked fashion (only over non-padded positions). + + Args: + text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): + Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). + sequence_lengths (`torch.Tensor of shape `(batch_size,)`): + The number of valid (non-padded) tokens for each batch instance. + device: (`str` or `torch.device`, *optional*): + torch device to place the resulting embeddings on + padding_side: (`str`, *optional*, defaults to `"left"`): + Whether the text tokenizer performs padding on the `"left"` or `"right"`. + scale_factor (`int`, *optional*, defaults to `8`): + Scaling factor to multiply the normalized hidden states by. + eps (`float`, *optional*, defaults to `1e-6`): + A small positive value for numerical stability when performing normalization. + + Returns: + `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: + Normed and flattened text encoder hidden states. + """ + batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + # Create padding mask + token_indices = torch.arange(seq_len, device=device).unsqueeze(0) + if padding_side == "right": + # For right padding, valid tokens are from 0 to sequence_length-1 + mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] + elif padding_side == "left": + # For left padding, valid tokens are from (T - sequence_length) to T-1 + start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] + mask = token_indices >= start_indices # [B, T] + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] + + # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) + masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) + num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) + masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) + + # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + # Normalization + normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) + normalized_hidden_states = normalized_hidden_states * scale_factor + + # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.flatten(2) + mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) + normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) + return normalized_hidden_states + + def _get_gemma_prompt_embeds( + self, + prompt: str | list[str], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`str` or `torch.device`): + torch device to place the resulting embeddings on + dtype: (`torch.dtype`): + torch dtype to cast the prompt embeds to + max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt. + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if getattr(self, "tokenizer", None) is not None: + # Gemma expects left padding for chat-style prompts + self.tokenizer.padding_side = "left" + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + prompt = [p.strip() for p in prompt] + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + text_input_ids = text_input_ids.to(device) + prompt_attention_mask = prompt_attention_mask.to(device) + + text_encoder_outputs = self.text_encoder( + input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + text_encoder_hidden_states = text_encoder_outputs.hidden_states + text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) + sequence_lengths = prompt_attention_mask.sum(dim=-1) + + prompt_embeds = self._pack_text_embeds( + text_encoder_hidden_states, + sequence_lengths, + device=device, + padding_side=self.tokenizer.padding_side, + scale_factor=scale_factor, + ) + prompt_embeds = prompt_embeds.to(dtype=dtype) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + @staticmethod + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. + batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + @staticmethod + def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents - latents_mean) / latents_std + + @staticmethod + def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents * latents_std) + latents_mean + + @staticmethod + def _create_noised_state( + latents: torch.Tensor, noise_scale: float | torch.Tensor, generator: torch.Generator | None = None + ): + noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) + noised_latents = noise_scale * noise + (1 - noise_scale) * latents + return noised_latents + + @staticmethod + def _pack_audio_latents( + latents: torch.Tensor, patch_size: int | None = None, patch_size_t: int | None = None + ) -> torch.Tensor: + # Audio latents shape: [B, C, L, M], where L is the latent audio length and M is the number of mel bins + if patch_size is not None and patch_size_t is not None: + # Packs the latents into a patch sequence of shape [B, L // p_t * M // p, C * p_t * p] (a ndim=3 tnesor). + # dim=1 is the effective audio sequence length and dim=2 is the effective audio input feature size. + batch_size, num_channels, latent_length, latent_mel_bins = latents.shape + post_patch_latent_length = latent_length / patch_size_t + post_patch_mel_bins = latent_mel_bins / patch_size + latents = latents.reshape( + batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size + ) + latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2) + else: + # Packs the latents into a patch sequence of shape [B, L, C * M]. This implicitly assumes a (mel) + # patch_size of M (all mel bins constitutes a single patch) and a patch_size_t of 1. + latents = latents.transpose(1, 2).flatten(2, 3) # [B, C, L, M] --> [B, L, C * M] + return latents + + @staticmethod + def _unpack_audio_latents( + latents: torch.Tensor, + latent_length: int, + num_mel_bins: int, + patch_size: int | None = None, + patch_size_t: int | None = None, + ) -> torch.Tensor: + # Unpacks an audio patch sequence of shape [B, S, D] into a latent spectrogram tensor of shape [B, C, L, M], + # where L is the latent audio length and M is the number of mel bins. + if patch_size is not None and patch_size_t is not None: + batch_size = latents.size(0) + latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size) + latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + else: + # Assume [B, S, D] = [B, L, C * M], which implies that patch_size = M and patch_size_t = 1. + latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) + return latents + + def prepare_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 768, + num_frames: int = 121, + noise_scale: float = 0.0, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + if latents.ndim == 5: + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + # latents are of shape [B, C, F, H, W], need to be packed + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + if latents.ndim != 3: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]." + ) + latents = self._create_noised_state(latents, noise_scale, generator) + return latents.to(device=device, dtype=dtype) + + height = height // self.vae_spatial_compression_ratio + width = width // self.vae_spatial_compression_ratio + num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + return latents + + def prepare_audio_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 8, + audio_latent_length: int = 1, # 1 is just a dummy value + num_mel_bins: int = 64, + noise_scale: float = 0.0, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + if latents.ndim == 4: + # latents are of shape [B, C, L, M], need to be packed + latents = self._pack_audio_latents(latents) + if latents.ndim != 3: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]." + ) + latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) + latents = self._create_noised_state(latents, noise_scale, generator) + return latents.to(device=device, dtype=dtype) + + # TODO: confirm whether this logic is correct + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + + shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_audio_latents(latents) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + frame_rate: float = 24.0, + num_inference_steps: int = 40, + sigmas: list[float] | None = None, + timesteps: list[int] = None, + guidance_scale: float = 4.0, + guidance_rescale: float = 0.0, + noise_scale: float = 0.0, + num_videos_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + audio_latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + decode_timestep: float | list[float] = 0.0, + decode_noise_scale: float | list[float] | None = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 1024, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to `512`): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to `768`): + The width in pixels of the generated image. This is set to 848 by default for the best results. + num_frames (`int`, *optional*, defaults to `121`): + The number of video frames to generate + frame_rate (`float`, *optional*, defaults to `24.0`): + The frames per second (FPS) of the generated video. + num_inference_steps (`int`, *optional*, defaults to 40): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to `4.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + noise_scale (`float`, *optional*, defaults to `0.0`): + The interpolation factor between random noise and denoised latents at each timestep. Applying noise to + the `latents` and `audio_latents` before continue denoising. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + audio_latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for audio + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `None`): + The interpolation factor between random noise and denoised latents at the decode timestep. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTX2PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*, defaults to `["latents"]`): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to `1024`): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ltx.LTX2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTX2PipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( + prompt_embeds, additive_attention_mask, additive_mask=True + ) + + # 4. Prepare latent variables + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + if latents is not None: + if latents.ndim == 5: + logger.info( + "Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred." + ) + _, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W] + elif latents.ndim == 3: + logger.warning( + f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be" + f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct." + ) + else: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]." + ) + video_sequence_length = latent_num_frames * latent_height * latent_width + + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + noise_scale, + torch.float32, + device, + generator, + latents, + ) + + duration_s = num_frames / frame_rate + audio_latents_per_second = ( + self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + if audio_latents is not None: + if audio_latents.ndim == 4: + logger.info( + "Got audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], `audio_num_frames` will be inferred." + ) + _, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M] + elif audio_latents.ndim == 3: + logger.warning( + f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims" + f" cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct." + ) + else: + raise ValueError( + f"Provided `audio_latents` tensor has shape {audio_latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins]." + ) + + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + num_channels_latents_audio = ( + self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 + ) + audio_latents = self.prepare_audio_latents( + batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, + num_mel_bins=num_mel_bins, + noise_scale=noise_scale, + dtype=torch.float32, + device=device, + generator=generator, + latents=audio_latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + mu = calculate_shift( + video_sequence_length, + self.scheduler.config.get("base_image_seq_len", 1024), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.95), + self.scheduler.config.get("max_shift", 2.05), + ) + # For now, duplicate the scheduler for use with the audio latents + audio_scheduler = copy.deepcopy(self.scheduler) + _, _ = retrieve_timesteps( + audio_scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Prepare micro-conditions + rope_interpolation_scale = ( + self.vae_temporal_compression_ratio / frame_rate, + self.vae_spatial_compression_ratio, + self.vae_spatial_compression_ratio, + ) + # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop + video_coords = self.transformer.rope.prepare_video_coords( + latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate + ) + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + audio_latents.shape[0], audio_num_frames, audio_latents.device + ) + # Duplicate the positional ids as well if using CFG + if self.do_classifier_free_guidance: + video_coords = video_coords.repeat((2,) + (1,) * (video_coords.ndim - 1)) # Repeat twice in batch dim + audio_coords = audio_coords.repeat((2,) + (1,) * (audio_coords.ndim - 1)) + + # 7. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + audio_latent_model_input = ( + torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + ) + audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + with self.transformer.cache_context("cond_uncond"): + noise_pred_video, noise_pred_audio = self.transformer( + hidden_states=latent_model_input, + audio_hidden_states=audio_latent_model_input, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, + timestep=timestep, + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, + # rope_interpolation_scale=rope_interpolation_scale, + attention_kwargs=attention_kwargs, + return_dict=False, + ) + noise_pred_video = noise_pred_video.float() + noise_pred_audio = noise_pred_audio.float() + + if self.do_classifier_free_guidance: + noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) + noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( + noise_pred_video_text - noise_pred_video_uncond + ) + + noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) + noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( + noise_pred_audio_text - noise_pred_audio_uncond + ) + + if self.guidance_rescale > 0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred_video = rescale_noise_cfg( + noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale + ) + noise_pred_audio = rescale_noise_cfg( + noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0] + # NOTE: for now duplicate scheduler for audio latents in case self.scheduler sets internal state in + # the step method (such as _step_index) + audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + + audio_latents = self._denormalize_audio_latents( + audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std + ) + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) + + if output_type == "latent": + video = latents + audio = audio_latents + else: + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + latents = latents.to(self.vae.dtype) + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + audio_latents = audio_latents.to(self.audio_vae.dtype) + generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] + audio = self.vocoder(generated_mel_spectrograms) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video, audio) + + return LTX2PipelineOutput(frames=video, audio=audio) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py new file mode 100644 index 0000000000000000000000000000000000000000..4c451330f4397396b5af40ec715651793adf3bbe --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/pipeline_ltx2_condition.py @@ -0,0 +1,1474 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +from dataclasses import dataclass +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin +from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video +from ...models.transformers import LTX2VideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .connectors import LTX2TextConnectors +from .pipeline_output import LTX2PipelineOutput +from .vocoder import LTX2Vocoder + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTX2ConditionPipeline + >>> from diffusers.pipelines.ltx2.export_utils import encode_video + >>> from diffusers.pipelines.ltx2.pipeline_ltx2_condition import LTX2VideoCondition + >>> from diffusers.utils import load_image + + >>> pipe = LTX2ConditionPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) + >>> pipe.enable_model_cpu_offload() + + >>> first_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png" + ... ) + >>> last_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png" + ... ) + >>> first_cond = LTX2VideoCondition(frames=first_image, index=0, strength=1.0) + >>> last_cond = LTX2VideoCondition(frames=last_image, index=-1, strength=1.0) + >>> conditions = [first_cond, last_cond] + >>> prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings." + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted, static" + + >>> frame_rate = 24.0 + >>> video = pipe( + ... conditions=conditions, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=768, + ... height=512, + ... num_frames=121, + ... frame_rate=frame_rate, + ... num_inference_steps=40, + ... guidance_scale=4.0, + ... output_type="np", + ... return_dict=False, + ... ) + >>> video = (video * 255).round().astype("uint8") + >>> video = torch.from_numpy(video) + + >>> encode_video( + ... video[0], + ... fps=frame_rate, + ... audio=audio[0].float().cpu(), + ... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000 + ... output_path="video.mp4", + ... ) + ``` +""" + + +@dataclass +class LTX2VideoCondition: + """ + Defines a single frame-conditioning item for LTX-2 Video - a single frame or a sequence of frames. + + Attributes: + frames (`PIL.Image.Image` or `List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor`): + The image (or video) to condition the video on. Accepts any type that can be handled by + VideoProcessor.preprocess_video. + index (`int`, defaults to `0`): + The index at which the image or video will conditionally affect the video generation. + strength (`float`, defaults to `1.0`): + The strength of the conditioning effect. A value of `1.0` means the conditioning effect is fully applied. + """ + + frames: PIL.Image.Image | list[PIL.Image.Image] | np.ndarray | torch.Tensor + index: int = 0 + strength: float = 1.0 + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): + r""" + Pipeline for video generation which allows image conditions to be inserted at arbitary parts of the video. + + Reference: https://github.com/Lightricks/LTX-Video + + TODO + """ + + model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTX2Video, + audio_vae: AutoencoderKLLTX2Audio, + text_encoder: Gemma3ForConditionalGeneration, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + connectors: LTX2TextConnectors, + transformer: LTX2VideoTransformer3DModel, + vocoder: LTX2Vocoder, + ): + super().__init__() + + self.register_modules( + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + connectors=connectors, + transformer=transformer, + vocoder=vocoder, + scheduler=scheduler, + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.audio_vae_mel_compression_ratio = ( + self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.audio_vae_temporal_compression_ratio = ( + self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1 + ) + + self.audio_sampling_rate = ( + self.audio_vae.config.sample_rate if getattr(self, "audio_vae", None) is not None else 16000 + ) + self.audio_hop_length = ( + self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio, resample="bilinear") + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 + ) + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds + def _pack_text_embeds( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device: str | torch.device, + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, + ) -> torch.Tensor: + """ + Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and + per-layer in a masked fashion (only over non-padded positions). + + Args: + text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): + Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). + sequence_lengths (`torch.Tensor of shape `(batch_size,)`): + The number of valid (non-padded) tokens for each batch instance. + device: (`str` or `torch.device`, *optional*): + torch device to place the resulting embeddings on + padding_side: (`str`, *optional*, defaults to `"left"`): + Whether the text tokenizer performs padding on the `"left"` or `"right"`. + scale_factor (`int`, *optional*, defaults to `8`): + Scaling factor to multiply the normalized hidden states by. + eps (`float`, *optional*, defaults to `1e-6`): + A small positive value for numerical stability when performing normalization. + + Returns: + `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: + Normed and flattened text encoder hidden states. + """ + batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + # Create padding mask + token_indices = torch.arange(seq_len, device=device).unsqueeze(0) + if padding_side == "right": + # For right padding, valid tokens are from 0 to sequence_length-1 + mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] + elif padding_side == "left": + # For left padding, valid tokens are from (T - sequence_length) to T-1 + start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] + mask = token_indices >= start_indices # [B, T] + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] + + # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) + masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) + num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) + masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) + + # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + # Normalization + normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) + normalized_hidden_states = normalized_hidden_states * scale_factor + + # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.flatten(2) + mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) + normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) + return normalized_hidden_states + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds + def _get_gemma_prompt_embeds( + self, + prompt: str | list[str], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`str` or `torch.device`): + torch device to place the resulting embeddings on + dtype: (`torch.dtype`): + torch dtype to cast the prompt embeds to + max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt. + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if getattr(self, "tokenizer", None) is not None: + # Gemma expects left padding for chat-style prompts + self.tokenizer.padding_side = "left" + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + prompt = [p.strip() for p in prompt] + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + text_input_ids = text_input_ids.to(device) + prompt_attention_mask = prompt_attention_mask.to(device) + + text_encoder_outputs = self.text_encoder( + input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + text_encoder_hidden_states = text_encoder_outputs.hidden_states + text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) + sequence_lengths = prompt_attention_mask.sum(dim=-1) + + prompt_embeds = self._pack_text_embeds( + text_encoder_hidden_states, + sequence_lengths, + device=device, + padding_side=self.tokenizer.padding_side, + scale_factor=scale_factor, + ) + prompt_embeds = prompt_embeds.to(dtype=dtype) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + latents=None, + audio_latents=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + if latents is not None and latents.ndim != 5: + raise ValueError( + f"Only unpacked (5D) video latents of shape `[batch_size, latent_channels, latent_frames," + f" latent_height, latent_width] are supported, but got {latents.ndim} dims. If you have packed (3D)" + f" latents, please unpack them (e.g. using the `_unpack_latents` method)." + ) + if audio_latents is not None and audio_latents.ndim != 4: + raise ValueError( + f"Only unpacked (4D) audio latents of shape `[batch_size, num_channels, audio_length, mel_bins] are" + f" supported, but got {latents.ndim} dims. If you have packed (3D) latents, please unpack them (e.g." + f" using the `_unpack_audio_latents` method)." + ) + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. + batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._create_noised_state + def _create_noised_state( + latents: torch.Tensor, noise_scale: float | torch.Tensor, generator: torch.Generator | None = None + ): + noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) + noised_latents = noise_scale * noise + (1 - noise_scale) * latents + return noised_latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_audio_latents + def _pack_audio_latents( + latents: torch.Tensor, patch_size: int | None = None, patch_size_t: int | None = None + ) -> torch.Tensor: + # Audio latents shape: [B, C, L, M], where L is the latent audio length and M is the number of mel bins + if patch_size is not None and patch_size_t is not None: + # Packs the latents into a patch sequence of shape [B, L // p_t * M // p, C * p_t * p] (a ndim=3 tnesor). + # dim=1 is the effective audio sequence length and dim=2 is the effective audio input feature size. + batch_size, num_channels, latent_length, latent_mel_bins = latents.shape + post_patch_latent_length = latent_length / patch_size_t + post_patch_mel_bins = latent_mel_bins / patch_size + latents = latents.reshape( + batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size + ) + latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2) + else: + # Packs the latents into a patch sequence of shape [B, L, C * M]. This implicitly assumes a (mel) + # patch_size of M (all mel bins constitutes a single patch) and a patch_size_t of 1. + latents = latents.transpose(1, 2).flatten(2, 3) # [B, C, L, M] --> [B, L, C * M] + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_audio_latents + def _unpack_audio_latents( + latents: torch.Tensor, + latent_length: int, + num_mel_bins: int, + patch_size: int | None = None, + patch_size_t: int | None = None, + ) -> torch.Tensor: + # Unpacks an audio patch sequence of shape [B, S, D] into a latent spectrogram tensor of shape [B, C, L, M], + # where L is the latent audio length and M is the number of mel bins. + if patch_size is not None and patch_size_t is not None: + batch_size = latents.size(0) + latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size) + latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + else: + # Assume [B, S, D] = [B, L, C * M], which implies that patch_size = M and patch_size_t = 1. + latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._normalize_audio_latents + def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents - latents_mean) / latents_std + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_audio_latents + def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents * latents_std) + latents_mean + + # Copied from diffusers.pipelines.ltx.pipeline_ltx_condition.LTXConditionPipeline.trim_conditioning_sequence + def trim_conditioning_sequence(self, start_frame: int, sequence_num_frames: int, target_num_frames: int) -> int: + """ + Trim a conditioning sequence to the allowed number of frames. + + Args: + start_frame (int): The target frame number of the first frame in the sequence. + sequence_num_frames (int): The number of frames in the sequence. + target_num_frames (int): The target number of frames in the generated video. + Returns: + int: updated sequence length + """ + scale_factor = self.vae_temporal_compression_ratio + num_frames = min(sequence_num_frames, target_num_frames - start_frame) + # Trim down to a multiple of temporal_scale_factor frames plus 1 + num_frames = (num_frames - 1) // scale_factor * scale_factor + 1 + return num_frames + + def preprocess_conditions( + self, + conditions: LTX2VideoCondition | list[LTX2VideoCondition] | None = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + device: torch.device | None = None, + ) -> tuple[list[torch.Tensor], list[float], list[int]]: + """ + Preprocesses the condition images/videos to torch tensors. + + Args: + conditions (`LTX2VideoCondition` or `List[LTX2VideoCondition]`, *optional*, defaults to `None`): + A list of image/video condition instances. + height (`int`, *optional*, defaults to `512`): + The desired height in pixels. + width (`int`, *optional*, defaults to `768`): + The desired width in pixels. + num_frames (`int`, *optional*, defaults to `121`): + The desired number of frames in the generated video. + device (`torch.device`, *optional*, defaults to `None`): + The device on which to put the preprocessed image/video tensors. + + Returns: + `Tuple[List[torch.Tensor], List[float], List[int]]`: + Returns a 3-tuple of lists of length `len(conditions)` as follows: + 1. The first list is a list of preprocessed video tensors of shape [batch_size=1, num_channels, + num_frames, height, width]. + 2. The second list is a list of conditioning strengths. + 3. The third list is a list of indices in latent space to insert the corresponding condition. + """ + conditioning_frames, conditioning_strengths, conditioning_indices = [], [], [] + + if conditions is None: + conditions = [] + if isinstance(conditions, LTX2VideoCondition): + conditions = [conditions] + + frame_scale_factor = self.vae_temporal_compression_ratio + latent_num_frames = (num_frames - 1) // frame_scale_factor + 1 + for i, condition in enumerate(conditions): + if isinstance(condition.frames, PIL.Image.Image): + # Single image, convert to List[PIL.Image.Image] + video_like_cond = [condition.frames] + elif isinstance(condition.frames, np.ndarray) and condition.frames.ndim == 3: + # Image-like ndarray of shape (H, W, C), insert frame dim in first axis + video_like_cond = np.expand_dims(condition.frames, axis=0) + elif isinstance(condition.frames, torch.Tensor) and condition.frames.ndim == 3: + # Image-like tensor of shape (C, H, W), insert frame dim in first dim + video_like_cond = condition.frames.unsqueeze(0) + else: + # Treat all other as videos. Note that this means 4D ndarrays and tensors will be treated as videos of + # shape (F, H, W, C) and (F, C, H, W), respectively. + video_like_cond = condition.frames + condition_pixels = self.video_processor.preprocess_video( + video_like_cond, height, width, resize_mode="crop" + ) + + # Interpret the index as a latent index, following the original LTX-2 code. + latent_start_idx = condition.index + # Support negative latent indices (e.g. -1 for the last latent index) + if latent_start_idx < 0: + # latent_start_idx will be positive because latent_num_frames is positive + latent_start_idx = latent_start_idx % latent_num_frames + if latent_start_idx >= latent_num_frames: + logger.warning( + f"The starting latent index {latent_start_idx} of condition {i} is too big for the specified number" + f" of latent frames {latent_num_frames}. This condition will be skipped." + ) + continue + + cond_num_frames = condition_pixels.size(2) + start_idx = max((latent_start_idx - 1) * frame_scale_factor + 1, 0) + truncated_cond_frames = self.trim_conditioning_sequence(start_idx, cond_num_frames, num_frames) + condition_pixels = condition_pixels[:, :, :truncated_cond_frames] + + conditioning_frames.append(condition_pixels.to(dtype=self.vae.dtype, device=device)) + conditioning_strengths.append(condition.strength) + conditioning_indices.append(latent_start_idx) + + return conditioning_frames, conditioning_strengths, conditioning_indices + + def apply_visual_conditioning( + self, + latents: torch.Tensor, + conditioning_mask: torch.Tensor, + condition_latents: list[torch.Tensor], + condition_strengths: list[float], + condition_indices: list[int], + latent_height: int, + latent_width: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Applies visual conditioning frames to an initial latent. + + Args: + latents (`torch.Tensor`): + Initial packed (patchified) latents of shape [batch_size, patch_seq_len, hidden_dim]. + conditioning_mask (`torch.Tensor`, *optional*): + Initial packed (patchified) conditioning mask of shape [batch_size, patch_seq_len, 1] with values in + [0, 1] where 0 means that the denoising model output will be fully used and 1 means that the condition + will be fully used (with intermediate values specifying a blend of the denoised and latent values). + + Returns: + `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: + Returns a 3-tuple of tensors where: + 1. The first element is the packed video latents (with unchanged shape [batch_size, patch_seq_len, + hidden_dim]) with the conditions applied + 2. The second element is the packed conditioning mask with conditioning strengths applied + 3. The third element holds the clean conditioning latents. + """ + # Latents-like tensor which holds the clean conditioning latents + clean_latents = torch.zeros_like(latents) + for cond, strength, latent_idx in zip(condition_latents, condition_strengths, condition_indices): + num_cond_tokens = cond.size(1) + start_token_idx = latent_idx * latent_height * latent_width + end_token_idx = start_token_idx + num_cond_tokens + + # Overwrite the portion of latents starting with start_token_idx with the condition + latents[:, start_token_idx:end_token_idx] = cond + conditioning_mask[:, start_token_idx:end_token_idx] = strength + clean_latents[:, start_token_idx:end_token_idx] = cond + + return latents, conditioning_mask, clean_latents + + def prepare_latents( + self, + conditions: LTX2VideoCondition | list[LTX2VideoCondition] | None = None, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 768, + num_frames: int = 121, + noise_scale: float = 1.0, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + + shape = (batch_size, num_channels_latents, latent_num_frames, latent_height, latent_width) + mask_shape = (batch_size, 1, latent_num_frames, latent_height, latent_width) + + if latents is not None: + # Latents are expected to be unpacked (5D) with shape [B, F, C, H, W] + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + else: + # NOTE: we set the initial latents to zeros rather a sample from the standard Gaussian prior because we + # will sample from the prior later once we have calculated the conditioning mask + latents = torch.zeros(shape, device=device, dtype=dtype) + + conditioning_mask = latents.new_zeros(mask_shape) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) # [B, seq_len, 1] + + if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape[:2]: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape[:2] + (num_channels_latents,)}." + ) + + if isinstance(generator, list): + logger.warning( + f"{self.__class__.__name__} does not support using a list of generators. The first generator in the" + f" list will be used for all (pseudo-)random operations." + ) + generator = generator[0] + + condition_frames, condition_strengths, condition_indices = self.preprocess_conditions( + conditions, height, width, num_frames, device=device + ) + condition_latents = [] + for condition_tensor in condition_frames: + condition_latent = retrieve_latents( + self.vae.encode(condition_tensor), generator=generator, sample_mode="argmax" + ) + condition_latent = self._normalize_latents( + condition_latent, self.vae.latents_mean, self.vae.latents_std + ).to(device=device, dtype=dtype) + condition_latent = self._pack_latents( + condition_latent, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + condition_latents.append(condition_latent) + + # NOTE: following the I2V pipeline, we return a conditioning mask. The original LTX 2 code uses a denoising + # mask, which is the inverse of the conditioning mask (`denoise_mask = 1 - conditioning_mask`) + latents, conditioning_mask, clean_latents = self.apply_visual_conditioning( + latents, + conditioning_mask, + condition_latents, + condition_strengths, + condition_indices, + latent_height=latent_height, + latent_width=latent_width, + ) + + # Sample from the standard Gaussian prior (or an intermediate Gaussian distribution if noise_scale < 1.0). + noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) + scaled_mask = (1.0 - conditioning_mask) * noise_scale + # Add noise to the `latents` so that it is at the noise level specified by `noise_scale`. + latents = noise * scaled_mask + latents * (1 - scaled_mask) + + return latents, conditioning_mask, clean_latents + + def prepare_audio_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 8, + audio_latent_length: int = 1, # 1 is just a dummy value + num_mel_bins: int = 64, + noise_scale: float = 0.0, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + # latents expected to be unpacked (4D) with shape [B, C, L, M] + latents = self._pack_audio_latents(latents) + latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) + latents = self._create_noised_state(latents, noise_scale, generator) + return latents.to(device=device, dtype=dtype) + + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + + shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_audio_latents(latents) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + conditions: LTX2VideoCondition | list[LTX2VideoCondition] | None = None, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + frame_rate: float = 24.0, + num_inference_steps: int = 40, + sigmas: list[float] | None = None, + timesteps: list[float] | None = None, + guidance_scale: float = 4.0, + guidance_rescale: float = 0.0, + noise_scale: float | None = None, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + audio_latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + decode_timestep: float | list[float] = 0.0, + decode_noise_scale: float | list[float] | None = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 1024, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + conditions (`List[LTXVideoCondition], *optional*`): + The list of frame-conditioning items for the video generation. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to `512`): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to `768`): + The width in pixels of the generated image. This is set to 848 by default for the best results. + num_frames (`int`, *optional*, defaults to `121`): + The number of video frames to generate + frame_rate (`float`, *optional*, defaults to `24.0`): + The frames per second (FPS) of the generated video. + num_inference_steps (`int`, *optional*, defaults to 40): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to `4.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + noise_scale (`float`, *optional*, defaults to `None`): + The interpolation factor between random noise and denoised latents at each timestep. Applying noise to + the `latents` and `audio_latents` before continue denoising. If not set, will be inferred from the + sigma schedule. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + audio_latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for audio + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `None`): + The interpolation factor between random noise and denoised latents at the decode timestep. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTX2PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to `1024`): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ltx.LTX2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTX2PipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + latents=latents, + audio_latents=audio_latents, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if conditions is not None and not isinstance(conditions, list): + conditions = [conditions] + + # Infer noise scale: first (largest) sigma value if using custom sigmas, else 1.0 + if noise_scale is None: + noise_scale = sigmas[0] if sigmas is not None else 1.0 + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( + prompt_embeds, additive_attention_mask, additive_mask=True + ) + + # 4. Prepare latent variables + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + if latents is not None: + logger.info( + "Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred." + ) + _, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W] + video_sequence_length = latent_num_frames * latent_height * latent_width + + num_channels_latents = self.transformer.config.in_channels + latents, conditioning_mask, clean_latents = self.prepare_latents( + conditions, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + noise_scale, + torch.float32, + device, + generator, + latents, + ) + if self.do_classifier_free_guidance: + conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) + + duration_s = num_frames / frame_rate + audio_latents_per_second = ( + self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + if audio_latents is not None: + logger.info( + "Got audio_latents of shape [batch_size, num_channels, audio_num_frames, mel_bins], `audio_num_frames` will be inferred." + ) + _, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M] + + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + num_channels_latents_audio = ( + self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 + ) + audio_latents = self.prepare_audio_latents( + batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, + num_mel_bins=num_mel_bins, + noise_scale=noise_scale, + dtype=torch.float32, + device=device, + generator=generator, + latents=audio_latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + mu = calculate_shift( + video_sequence_length, + self.scheduler.config.get("base_image_seq_len", 1024), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.95), + self.scheduler.config.get("max_shift", 2.05), + ) + + # For now, duplicate the scheduler for use with the audio latents + audio_scheduler = copy.deepcopy(self.scheduler) + _, _ = retrieve_timesteps( + audio_scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Prepare micro-conditions + rope_interpolation_scale = ( + self.vae_temporal_compression_ratio / frame_rate, + self.vae_spatial_compression_ratio, + self.vae_spatial_compression_ratio, + ) + # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop + video_coords = self.transformer.rope.prepare_video_coords( + latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate + ) + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + audio_latents.shape[0], audio_num_frames, audio_latents.device + ) + # Duplicate the positional ids as well if using CFG + if self.do_classifier_free_guidance: + video_coords = video_coords.repeat((2,) + (1,) * (video_coords.ndim - 1)) # Repeat twice in batch dim + audio_coords = audio_coords.repeat((2,) + (1,) * (audio_coords.ndim - 1)) + + # 7. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + audio_latent_model_input = ( + torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + ) + audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) + + timestep = t.expand(latent_model_input.shape[0]) + video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask.squeeze(-1)) + + with self.transformer.cache_context("cond_uncond"): + noise_pred_video, noise_pred_audio = self.transformer( + hidden_states=latent_model_input, + audio_hidden_states=audio_latent_model_input, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, + timestep=video_timestep, + audio_timestep=timestep, + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, + # rope_interpolation_scale=rope_interpolation_scale, + attention_kwargs=attention_kwargs, + return_dict=False, + ) + noise_pred_video = noise_pred_video.float() + noise_pred_audio = noise_pred_audio.float() + + if self.do_classifier_free_guidance: + noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) + noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( + noise_pred_video_text - noise_pred_video_uncond + ) + + noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) + noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( + noise_pred_audio_text - noise_pred_audio_uncond + ) + + if self.guidance_rescale > 0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred_video = rescale_noise_cfg( + noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale + ) + noise_pred_audio = rescale_noise_cfg( + noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale + ) + + # NOTE: use only the first chunk of conditioning mask in case it is duplicated for CFG + bsz = noise_pred_video.size(0) + sigma = self.scheduler.sigmas[i] + # Convert the noise_pred_video velocity model prediction into a sample (x0) prediction + denoised_sample = latents - noise_pred_video * sigma + # Apply the (packed) conditioning mask to the denoised (x0) sample and clean conditioning. The + # conditioning mask contains conditioning strengths from 0 (always use denoised sample) to 1 (always + # use conditions), with intermediate values specifying how strongly to follow the conditions. + denoised_sample_cond = ( + denoised_sample * (1 - conditioning_mask[:bsz]) + clean_latents.float() * conditioning_mask[:bsz] + ).to(noise_pred_video.dtype) + # Convert the denoised (x0) sample back to a velocity for the scheduler + denoised_latents_cond = ((latents - denoised_sample_cond) / sigma).to(noise_pred_video.dtype) + + # Compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(denoised_latents_cond, t, latents, return_dict=False)[0] + + # NOTE: for now duplicate scheduler for audio latents in case self.scheduler sets internal state in + # the step method (such as _step_index) + audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + + audio_latents = self._denormalize_audio_latents( + audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std + ) + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) + + if output_type == "latent": + video = latents + audio = audio_latents + else: + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + latents = latents.to(self.vae.dtype) + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + audio_latents = audio_latents.to(self.audio_vae.dtype) + generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] + audio = self.vocoder(generated_mel_spectrograms) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video, audio) + + return LTX2PipelineOutput(frames=video, audio=audio) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py new file mode 100644 index 0000000000000000000000000000000000000000..83ba2cd7c685ff7212a0a876b9178aaa366ccdee --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -0,0 +1,1317 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin +from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video +from ...models.transformers import LTX2VideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .connectors import LTX2TextConnectors +from .pipeline_output import LTX2PipelineOutput +from .vocoder import LTX2Vocoder + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTX2ImageToVideoPipeline + >>> from diffusers.pipelines.ltx2.export_utils import encode_video + >>> from diffusers.utils import load_image + + >>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) + >>> pipe.enable_model_cpu_offload() + + >>> image = load_image( + ... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png" + ... ) + >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background." + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> frame_rate = 24.0 + >>> video, audio = pipe( + ... image=image, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=768, + ... height=512, + ... num_frames=121, + ... frame_rate=frame_rate, + ... num_inference_steps=40, + ... guidance_scale=4.0, + ... output_type="np", + ... return_dict=False, + ... ) + + >>> encode_video( + ... video[0], + ... fps=frame_rate, + ... audio=audio[0].float().cpu(), + ... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000 + ... output_path="video.mp4", + ... ) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): + r""" + Pipeline for image-to-video generation. + + Reference: https://github.com/Lightricks/LTX-Video + + TODO + """ + + model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTX2Video, + audio_vae: AutoencoderKLLTX2Audio, + text_encoder: Gemma3ForConditionalGeneration, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + connectors: LTX2TextConnectors, + transformer: LTX2VideoTransformer3DModel, + vocoder: LTX2Vocoder, + ): + super().__init__() + + self.register_modules( + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + connectors=connectors, + transformer=transformer, + vocoder=vocoder, + scheduler=scheduler, + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + # TODO: check whether the MEL compression ratio logic here is corrct + self.audio_vae_mel_compression_ratio = ( + self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.audio_vae_temporal_compression_ratio = ( + self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1 + ) + + self.audio_sampling_rate = ( + self.audio_vae.config.sample_rate if getattr(self, "audio_vae", None) is not None else 16000 + ) + self.audio_hop_length = ( + self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio, resample="bilinear") + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 + ) + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds + def _pack_text_embeds( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device: str | torch.device, + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, + ) -> torch.Tensor: + """ + Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and + per-layer in a masked fashion (only over non-padded positions). + + Args: + text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): + Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). + sequence_lengths (`torch.Tensor of shape `(batch_size,)`): + The number of valid (non-padded) tokens for each batch instance. + device: (`str` or `torch.device`, *optional*): + torch device to place the resulting embeddings on + padding_side: (`str`, *optional*, defaults to `"left"`): + Whether the text tokenizer performs padding on the `"left"` or `"right"`. + scale_factor (`int`, *optional*, defaults to `8`): + Scaling factor to multiply the normalized hidden states by. + eps (`float`, *optional*, defaults to `1e-6`): + A small positive value for numerical stability when performing normalization. + + Returns: + `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: + Normed and flattened text encoder hidden states. + """ + batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + # Create padding mask + token_indices = torch.arange(seq_len, device=device).unsqueeze(0) + if padding_side == "right": + # For right padding, valid tokens are from 0 to sequence_length-1 + mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] + elif padding_side == "left": + # For left padding, valid tokens are from (T - sequence_length) to T-1 + start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] + mask = token_indices >= start_indices # [B, T] + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] + + # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) + masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) + num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) + masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) + + # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + # Normalization + normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) + normalized_hidden_states = normalized_hidden_states * scale_factor + + # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.flatten(2) + mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) + normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) + return normalized_hidden_states + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds + def _get_gemma_prompt_embeds( + self, + prompt: str | list[str], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`str` or `torch.device`): + torch device to place the resulting embeddings on + dtype: (`torch.dtype`): + torch dtype to cast the prompt embeds to + max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt. + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if getattr(self, "tokenizer", None) is not None: + # Gemma expects left padding for chat-style prompts + self.tokenizer.padding_side = "left" + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + prompt = [p.strip() for p in prompt] + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + text_input_ids = text_input_ids.to(device) + prompt_attention_mask = prompt_attention_mask.to(device) + + text_encoder_outputs = self.text_encoder( + input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + text_encoder_hidden_states = text_encoder_outputs.hidden_states + text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) + sequence_lengths = prompt_attention_mask.sum(dim=-1) + + prompt_embeds = self._pack_text_embeds( + text_encoder_hidden_states, + sequence_lengths, + device=device, + padding_side=self.tokenizer.padding_side, + scale_factor=scale_factor, + ) + prompt_embeds = prompt_embeds.to(dtype=dtype) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. + batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + @staticmethod + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._create_noised_state + def _create_noised_state( + latents: torch.Tensor, noise_scale: float | torch.Tensor, generator: torch.Generator | None = None + ): + noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) + noised_latents = noise_scale * noise + (1 - noise_scale) * latents + return noised_latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_audio_latents + def _pack_audio_latents( + latents: torch.Tensor, patch_size: int | None = None, patch_size_t: int | None = None + ) -> torch.Tensor: + # Audio latents shape: [B, C, L, M], where L is the latent audio length and M is the number of mel bins + if patch_size is not None and patch_size_t is not None: + # Packs the latents into a patch sequence of shape [B, L // p_t * M // p, C * p_t * p] (a ndim=3 tnesor). + # dim=1 is the effective audio sequence length and dim=2 is the effective audio input feature size. + batch_size, num_channels, latent_length, latent_mel_bins = latents.shape + post_patch_latent_length = latent_length / patch_size_t + post_patch_mel_bins = latent_mel_bins / patch_size + latents = latents.reshape( + batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size + ) + latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2) + else: + # Packs the latents into a patch sequence of shape [B, L, C * M]. This implicitly assumes a (mel) + # patch_size of M (all mel bins constitutes a single patch) and a patch_size_t of 1. + latents = latents.transpose(1, 2).flatten(2, 3) # [B, C, L, M] --> [B, L, C * M] + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_audio_latents + def _unpack_audio_latents( + latents: torch.Tensor, + latent_length: int, + num_mel_bins: int, + patch_size: int | None = None, + patch_size_t: int | None = None, + ) -> torch.Tensor: + # Unpacks an audio patch sequence of shape [B, S, D] into a latent spectrogram tensor of shape [B, C, L, M], + # where L is the latent audio length and M is the number of mel bins. + if patch_size is not None and patch_size_t is not None: + batch_size = latents.size(0) + latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size) + latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + else: + # Assume [B, S, D] = [B, L, C * M], which implies that patch_size = M and patch_size_t = 1. + latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._normalize_audio_latents + def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents - latents_mean) / latents_std + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_audio_latents + def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents * latents_std) + latents_mean + + def prepare_latents( + self, + image: torch.Tensor | None = None, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 704, + num_frames: int = 161, + noise_scale: float = 0.0, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + height = height // self.vae_spatial_compression_ratio + width = width // self.vae_spatial_compression_ratio + num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + mask_shape = (batch_size, 1, num_frames, height, width) + + if latents is not None: + if latents.ndim == 5: + # conditioning_mask needs to the same shape as latents in two stages generation. + batch_size, _, num_frames, height, width = latents.shape + mask_shape = (batch_size, 1, num_frames, height, width) + conditioning_mask = latents.new_zeros(mask_shape) + conditioning_mask[:, :, 0] = 1.0 + + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = self._create_noised_state(latents, noise_scale * (1 - conditioning_mask), generator) + # latents are of shape [B, C, F, H, W], need to be packed + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + else: + conditioning_mask = latents.new_zeros(mask_shape) + conditioning_mask[:, :, 0] = 1.0 + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ).squeeze(-1) + if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape + (num_channels_latents,)}." + ) + return latents.to(device=device, dtype=dtype), conditioning_mask + + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0).unsqueeze(2)), generator[i], "argmax") + for i in range(batch_size) + ] + else: + init_latents = [ + retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2)), generator, "argmax") for img in image + ] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std) + init_latents = init_latents.repeat(1, 1, num_frames, 1, 1) + + # First condition is image latents and those should be kept clean. + conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype) + conditioning_mask[:, :, 0] = 1.0 + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # Interpolation. + latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask) + + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ).squeeze(-1) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + return latents, conditioning_mask + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.prepare_audio_latents + def prepare_audio_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 8, + audio_latent_length: int = 1, # 1 is just a dummy value + num_mel_bins: int = 64, + noise_scale: float = 0.0, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + if latents.ndim == 4: + # latents are of shape [B, C, L, M], need to be packed + latents = self._pack_audio_latents(latents) + if latents.ndim != 3: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]." + ) + latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) + latents = self._create_noised_state(latents, noise_scale, generator) + return latents.to(device=device, dtype=dtype) + + # TODO: confirm whether this logic is correct + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + + shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_audio_latents(latents) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput = None, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + frame_rate: float = 24.0, + num_inference_steps: int = 40, + sigmas: list[float] | None = None, + timesteps: list[int] | None = None, + guidance_scale: float = 4.0, + guidance_rescale: float = 0.0, + noise_scale: float = 0.0, + num_videos_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + audio_latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + decode_timestep: float | list[float] = 0.0, + decode_noise_scale: float | list[float] | None = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 1024, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to `512`): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to `768`): + The width in pixels of the generated image. This is set to 848 by default for the best results. + num_frames (`int`, *optional*, defaults to `121`): + The number of video frames to generate + frame_rate (`float`, *optional*, defaults to `24.0`): + The frames per second (FPS) of the generated video. + num_inference_steps (`int`, *optional*, defaults to 40): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to `4.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + noise_scale (`float`, *optional*, defaults to `0.0`): + The interpolation factor between random noise and denoised latents at each timestep. Applying noise to + the `latents` and `audio_latents` before continue denoising. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + audio_latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for audio + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `None`): + The interpolation factor between random noise and denoised latents at the decode timestep. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTX2PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to `1024`): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ltx.LTX2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTX2PipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( + prompt_embeds, additive_attention_mask, additive_mask=True + ) + + # 4. Prepare latent variables + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + if latents is not None: + if latents.ndim == 5: + logger.info( + "Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred." + ) + _, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W] + elif latents.ndim == 3: + logger.warning( + f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be" + f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct." + ) + else: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]." + ) + video_sequence_length = latent_num_frames * latent_height * latent_width + + if latents is None: + image = self.video_processor.preprocess(image, height=height, width=width) + image = image.to(device=device, dtype=prompt_embeds.dtype) + + num_channels_latents = self.transformer.config.in_channels + latents, conditioning_mask = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + noise_scale, + torch.float32, + device, + generator, + latents, + ) + if self.do_classifier_free_guidance: + conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) + + duration_s = num_frames / frame_rate + audio_latents_per_second = ( + self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + if audio_latents is not None: + if audio_latents.ndim == 4: + logger.info( + "Got audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], `audio_num_frames` will be inferred." + ) + _, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M] + elif audio_latents.ndim == 3: + logger.warning( + f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims" + f" cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct." + ) + else: + raise ValueError( + f"Provided `audio_latents` tensor has shape {audio_latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins]." + ) + + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + num_channels_latents_audio = ( + self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 + ) + audio_latents = self.prepare_audio_latents( + batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, + num_mel_bins=num_mel_bins, + noise_scale=noise_scale, + dtype=torch.float32, + device=device, + generator=generator, + latents=audio_latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + mu = calculate_shift( + video_sequence_length, + self.scheduler.config.get("base_image_seq_len", 1024), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.95), + self.scheduler.config.get("max_shift", 2.05), + ) + + # For now, duplicate the scheduler for use with the audio latents + audio_scheduler = copy.deepcopy(self.scheduler) + _, _ = retrieve_timesteps( + audio_scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Prepare micro-conditions + rope_interpolation_scale = ( + self.vae_temporal_compression_ratio / frame_rate, + self.vae_spatial_compression_ratio, + self.vae_spatial_compression_ratio, + ) + # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop + video_coords = self.transformer.rope.prepare_video_coords( + latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate + ) + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + audio_latents.shape[0], audio_num_frames, audio_latents.device + ) + # Duplicate the positional ids as well if using CFG + if self.do_classifier_free_guidance: + video_coords = video_coords.repeat((2,) + (1,) * (video_coords.ndim - 1)) # Repeat twice in batch dim + audio_coords = audio_coords.repeat((2,) + (1,) * (audio_coords.ndim - 1)) + + # 7. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + audio_latent_model_input = ( + torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + ) + audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) + + timestep = t.expand(latent_model_input.shape[0]) + video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) + + with self.transformer.cache_context("cond_uncond"): + noise_pred_video, noise_pred_audio = self.transformer( + hidden_states=latent_model_input, + audio_hidden_states=audio_latent_model_input, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, + timestep=video_timestep, + audio_timestep=timestep, + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, + # rope_interpolation_scale=rope_interpolation_scale, + attention_kwargs=attention_kwargs, + return_dict=False, + ) + noise_pred_video = noise_pred_video.float() + noise_pred_audio = noise_pred_audio.float() + + if self.do_classifier_free_guidance: + noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) + noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( + noise_pred_video_text - noise_pred_video_uncond + ) + + noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) + noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( + noise_pred_audio_text - noise_pred_audio_uncond + ) + + if self.guidance_rescale > 0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred_video = rescale_noise_cfg( + noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale + ) + noise_pred_audio = rescale_noise_cfg( + noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + noise_pred_video = self._unpack_latents( + noise_pred_video, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + + noise_pred_video = noise_pred_video[:, :, 1:] + noise_latents = latents[:, :, 1:] + pred_latents = self.scheduler.step(noise_pred_video, t, noise_latents, return_dict=False)[0] + + latents = torch.cat([latents[:, :, :1], pred_latents], dim=2) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + # NOTE: for now duplicate scheduler for audio latents in case self.scheduler sets internal state in + # the step method (such as _step_index) + audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + + audio_latents = self._denormalize_audio_latents( + audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std + ) + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) + + if output_type == "latent": + video = latents + audio = audio_latents + else: + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + latents = latents.to(self.vae.dtype) + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + audio_latents = audio_latents.to(self.audio_vae.dtype) + generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] + audio = self.vocoder(generated_mel_spectrograms) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video, audio) + + return LTX2PipelineOutput(frames=video, audio=audio) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py new file mode 100644 index 0000000000000000000000000000000000000000..00d81dfd11c37e23cd39e8191444bd9435673095 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py @@ -0,0 +1,425 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLLTX2Video +from ...utils import get_logger, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..ltx.pipeline_output import LTXPipelineOutput +from ..pipeline_utils import DiffusionPipeline +from .latent_upsampler import LTX2LatentUpsamplerModel + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline + >>> from diffusers.pipelines.ltx2.export_utils import encode_video + >>> from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel + >>> from diffusers.utils import load_image + + >>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) + >>> pipe.enable_model_cpu_offload() + + >>> image = load_image( + ... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png" + ... ) + >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background." + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> frame_rate = 24.0 + >>> video, audio = pipe( + ... image=image, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=768, + ... height=512, + ... num_frames=121, + ... frame_rate=frame_rate, + ... num_inference_steps=40, + ... guidance_scale=4.0, + ... output_type="pil", + ... return_dict=False, + ... ) + + >>> latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained( + ... "Lightricks/LTX-2", subfolder="latent_upsampler", torch_dtype=torch.bfloat16 + ... ) + >>> upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler) + >>> upsample_pipe.vae.enable_tiling() + >>> upsample_pipe.to(device="cuda", dtype=torch.bfloat16) + + >>> video = upsample_pipe( + ... video=video, + ... width=768, + ... height=512, + ... output_type="np", + ... return_dict=False, + ... )[0] + + >>> encode_video( + ... video[0], + ... fps=frame_rate, + ... audio=audio[0].float().cpu(), + ... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000 + ... output_path="video.mp4", + ... ) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class LTX2LatentUpsamplePipeline(DiffusionPipeline): + model_cpu_offload_seq = "vae->latent_upsampler" + + def __init__( + self, + vae: AutoencoderKLLTX2Video, + latent_upsampler: LTX2LatentUpsamplerModel, + ) -> None: + super().__init__() + + self.register_modules(vae=vae, latent_upsampler=latent_upsampler) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + + def prepare_latents( + self, + video: torch.Tensor | None = None, + batch_size: int = 1, + num_frames: int = 121, + height: int = 512, + width: int = 768, + spatial_patch_size: int = 1, + temporal_patch_size: int = 1, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + if latents.ndim == 3: + # Convert token seq [B, S, D] to latent video [B, C, F, H, W] + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + latents = self._unpack_latents( + latents, latent_num_frames, latent_height, latent_width, spatial_patch_size, temporal_patch_size + ) + return latents.to(device=device, dtype=dtype) + + video = video.to(device=device, dtype=self.vae.dtype) + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + init_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + else: + init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + # NOTE: latent upsampler operates on the unnormalized latents, so don't normalize here + # init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std) + return init_latents + + def adain_filter_latent(self, latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0): + """ + Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on statistics from a reference latent + tensor. + + Args: + latent (`torch.Tensor`): + Input latents to normalize + reference_latents (`torch.Tensor`): + The reference latents providing style statistics. + factor (`float`): + Blending factor between original and transformed latent. Range: -10.0 to 10.0, Default: 1.0 + + Returns: + torch.Tensor: The transformed latent tensor + """ + result = latents.clone() + + for i in range(latents.size(0)): + for c in range(latents.size(1)): + r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None) # index by original dim order + i_sd, i_mean = torch.std_mean(result[i, c], dim=None) + + result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean + + result = torch.lerp(latents, result, factor) + return result + + def tone_map_latents(self, latents: torch.Tensor, compression: float) -> torch.Tensor: + """ + Applies a non-linear tone-mapping function to latent values to reduce their dynamic range in a perceptually + smooth way using a sigmoid-based compression. + + This is useful for regularizing high-variance latents or for conditioning outputs during generation, especially + when controlling dynamic behavior with a `compression` factor. + + Args: + latents : torch.Tensor + Input latent tensor with arbitrary shape. Expected to be roughly in [-1, 1] or [0, 1] range. + compression : float + Compression strength in the range [0, 1]. + - 0.0: No tone-mapping (identity transform) + - 1.0: Full compression effect + + Returns: + torch.Tensor + The tone-mapped latent tensor of the same shape as input. + """ + # Remap [0-1] to [0-0.75] and apply sigmoid compression in one shot + scale_factor = compression * 0.75 + abs_latents = torch.abs(latents) + + # Sigmoid compression: sigmoid shifts large values toward 0.2, small values stay ~1.0 + # When scale_factor=0, sigmoid term vanishes, when scale_factor=0.75, full effect + sigmoid_term = torch.sigmoid(4.0 * scale_factor * (abs_latents - 1.0)) + scales = 1.0 - 0.8 * scale_factor * sigmoid_term + + filtered = latents * scales + return filtered + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. + batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + def check_inputs(self, video, height, width, latents, tone_map_compression_ratio): + if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if video is not None and latents is not None: + raise ValueError("Only one of `video` or `latents` can be provided.") + if video is None and latents is None: + raise ValueError("One of `video` or `latents` has to be provided.") + + if not (0 <= tone_map_compression_ratio <= 1): + raise ValueError("`tone_map_compression_ratio` must be in the range [0, 1]") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + video: list[PipelineImageInput] | None = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + spatial_patch_size: int = 1, + temporal_patch_size: int = 1, + latents: torch.Tensor | None = None, + latents_normalized: bool = False, + decode_timestep: float | list[float] = 0.0, + decode_noise_scale: float | list[float] | None = None, + adain_factor: float = 0.0, + tone_map_compression_ratio: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + video (`list[PipelineImageInput]`, *optional*) + The video to be upsampled (such as a LTX 2.0 first stage output). If not supplied, `latents` should be + supplied. + height (`int`, *optional*, defaults to `512`): + The height in pixels of the input video (not the generated video, which will have a larger resolution). + width (`int`, *optional*, defaults to `768`): + The width in pixels of the input video (not the generated video, which will have a larger resolution). + num_frames (`int`, *optional*, defaults to `121`): + The number of frames in the input video. + spatial_patch_size (`int`, *optional*, defaults to `1`): + The spatial patch size of the video latents. Used when `latents` is supplied if unpacking is necessary. + temporal_patch_size (`int`, *optional*, defaults to `1`): + The temporal patch size of the video latents. Used when `latents` is supplied if unpacking is + necessary. + latents (`torch.Tensor`, *optional*): + Pre-generated video latents. This can be supplied in place of the `video` argument. Can either be a + patch sequence of shape `(batch_size, seq_len, hidden_dim)` or a video latent of shape `(batch_size, + latent_channels, latent_frames, latent_height, latent_width)`. + latents_normalized (`bool`, *optional*, defaults to `False`) + If `latents` are supplied, whether the `latents` are normalized using the VAE latent mean and std. If + `True`, the `latents` will be denormalized before being supplied to the latent upsampler. + decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `None`): + The interpolation factor between random noise and denoised latents at the decode timestep. + adain_factor (`float`, *optional*, defaults to `0.0`): + Adaptive Instance Normalization (AdaIN) blending factor between the upsampled and original latents. + Should be in [-10.0, 10.0]; supplying 0.0 (the default) means that AdaIN is not performed. + tone_map_compression_ratio (`float`, *optional*, defaults to `0.0`): + The compression strength for tone mapping, which will reduce the dynamic range of the latent values. + This is useful for regularizing high-variance latents or for conditioning outputs during generation. + Should be in [0, 1], where 0.0 (the default) means tone mapping is not applied and 1.0 corresponds to + the full compression effect. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is the upsampled video. + """ + + self.check_inputs( + video=video, + height=height, + width=width, + latents=latents, + tone_map_compression_ratio=tone_map_compression_ratio, + ) + + if video is not None: + # Batched video input is not yet tested/supported. TODO: take a look later + batch_size = 1 + else: + batch_size = latents.shape[0] + device = self._execution_device + + if video is not None: + num_frames = len(video) + if num_frames % self.vae_temporal_compression_ratio != 1: + num_frames = ( + num_frames // self.vae_temporal_compression_ratio * self.vae_temporal_compression_ratio + 1 + ) + video = video[:num_frames] + logger.warning( + f"Video length expected to be of the form `k * {self.vae_temporal_compression_ratio} + 1` but is {len(video)}. Truncating to {num_frames} frames." + ) + video = self.video_processor.preprocess_video(video, height=height, width=width) + video = video.to(device=device, dtype=torch.float32) + + latents_supplied = latents is not None + latents = self.prepare_latents( + video=video, + batch_size=batch_size, + num_frames=num_frames, + height=height, + width=width, + spatial_patch_size=spatial_patch_size, + temporal_patch_size=temporal_patch_size, + dtype=torch.float32, + device=device, + generator=generator, + latents=latents, + ) + + if latents_supplied and latents_normalized: + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = latents.to(self.latent_upsampler.dtype) + latents_upsampled = self.latent_upsampler(latents) + + if adain_factor > 0.0: + latents = self.adain_filter_latent(latents_upsampled, latents, adain_factor) + else: + latents = latents_upsampled + + if tone_map_compression_ratio > 0.0: + latents = self.tone_map_latents(latents, tone_map_compression_ratio) + + if output_type == "latent": + video = latents + else: + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return LTXPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..02891219ee6768bb0a76b2a36c11e3541b51081d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/pipeline_output.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class LTX2PipelineOutput(BaseOutput): + r""" + Output class for LTX pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + audio (`torch.Tensor`, `np.ndarray`): + TODO + """ + + frames: torch.Tensor + audio: torch.Tensor diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f80469817fe6d32884820b838b62789a1088f106 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/utils.py @@ -0,0 +1,6 @@ +# Pre-trained sigma values for distilled model are taken from +# https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py +DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875] + +# Reduced schedule for super-resolution stage 2 (subset of distilled values) +STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/vocoder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/vocoder.py new file mode 100644 index 0000000000000000000000000000000000000000..551c3ac5980f1daeed53ce73bb4c7f29f4391d77 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ltx2/vocoder.py @@ -0,0 +1,158 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin + + +class ResBlock(nn.Module): + def __init__( + self, + channels: int, + kernel_size: int = 3, + stride: int = 1, + dilations: tuple[int, ...] = (1, 3, 5), + leaky_relu_negative_slope: float = 0.1, + padding_mode: str = "same", + ): + super().__init__() + self.dilations = dilations + self.negative_slope = leaky_relu_negative_slope + + self.convs1 = nn.ModuleList( + [ + nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=dilation, padding=padding_mode) + for dilation in dilations + ] + ) + + self.convs2 = nn.ModuleList( + [ + nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=1, padding=padding_mode) + for _ in range(len(dilations)) + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for conv1, conv2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, negative_slope=self.negative_slope) + xt = conv1(xt) + xt = F.leaky_relu(xt, negative_slope=self.negative_slope) + xt = conv2(xt) + x = x + xt + return x + + +class LTX2Vocoder(ModelMixin, ConfigMixin): + r""" + LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms. + """ + + @register_to_config + def __init__( + self, + in_channels: int = 128, + hidden_channels: int = 1024, + out_channels: int = 2, + upsample_kernel_sizes: list[int] = [16, 15, 8, 4, 4], + upsample_factors: list[int] = [6, 5, 2, 2, 2], + resnet_kernel_sizes: list[int] = [3, 7, 11], + resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + leaky_relu_negative_slope: float = 0.1, + output_sampling_rate: int = 24000, + ): + super().__init__() + self.num_upsample_layers = len(upsample_kernel_sizes) + self.resnets_per_upsample = len(resnet_kernel_sizes) + self.out_channels = out_channels + self.total_upsample_factor = math.prod(upsample_factors) + self.negative_slope = leaky_relu_negative_slope + + if self.num_upsample_layers != len(upsample_factors): + raise ValueError( + f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length" + f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively." + ) + + if self.resnets_per_upsample != len(resnet_dilations): + raise ValueError( + f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length" + f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively." + ) + + self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3) + + self.upsamplers = nn.ModuleList() + self.resnets = nn.ModuleList() + input_channels = hidden_channels + for i, (stride, kernel_size) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): + output_channels = input_channels // 2 + self.upsamplers.append( + nn.ConvTranspose1d( + input_channels, # hidden_channels // (2 ** i) + output_channels, # hidden_channels // (2 ** (i + 1)) + kernel_size, + stride=stride, + padding=(kernel_size - stride) // 2, + ) + ) + + for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations): + self.resnets.append( + ResBlock( + output_channels, + kernel_size, + dilations=dilations, + leaky_relu_negative_slope=leaky_relu_negative_slope, + ) + ) + input_channels = output_channels + + self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3) + + def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor: + r""" + Forward pass of the vocoder. + + Args: + hidden_states (`torch.Tensor`): + Input Mel spectrogram tensor of shape `(batch_size, num_channels, time, num_mel_bins)` if `time_last` + is `False` (the default) or shape `(batch_size, num_channels, num_mel_bins, time)` if `time_last` is + `True`. + time_last (`bool`, *optional*, defaults to `False`): + Whether the last dimension of the input is the time/frame dimension or the Mel bins dimension. + + Returns: + `torch.Tensor`: + Audio waveform tensor of shape (batch_size, out_channels, audio_length) + """ + + # Ensure that the time/frame dimension is last + if not time_last: + hidden_states = hidden_states.transpose(2, 3) + # Combine channels and frequency (mel bins) dimensions + hidden_states = hidden_states.flatten(1, 2) + + hidden_states = self.conv_in(hidden_states) + + for i in range(self.num_upsample_layers): + hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope) + hidden_states = self.upsamplers[i](hidden_states) + + # Run all resnets in parallel on hidden_states + start = i * self.resnets_per_upsample + end = (i + 1) * self.resnets_per_upsample + resnet_outputs = torch.stack([self.resnets[j](hidden_states) for j in range(start, end)], dim=0) + + hidden_states = torch.mean(resnet_outputs, dim=0) + + # NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of + # 0.01 (whereas the others usually use a slope of 0.1). Not sure if this is intended + hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01) + hidden_states = self.conv_out(hidden_states) + hidden_states = torch.tanh(hidden_states) + + return hidden_states diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lucy/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lucy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..580e1f37f30a4bfc91878692c9e0004a233aa64e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lucy/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_lucy_edit"] = ["LucyEditPipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_lucy_edit import LucyEditPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lucy/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lucy/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..926a884be8e4cbb8b989a75fbd0f586f01887d11 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lucy/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lucy/__pycache__/pipeline_lucy_edit.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lucy/__pycache__/pipeline_lucy_edit.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4116aee1d977e75558240238dadf8dd31c8f5cd3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lucy/__pycache__/pipeline_lucy_edit.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lucy/__pycache__/pipeline_output.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lucy/__pycache__/pipeline_output.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efb96b8e4a833c62d10f881d8776c12c41b79a19 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lucy/__pycache__/pipeline_output.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lucy/pipeline_lucy_edit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lucy/pipeline_lucy_edit.py new file mode 100644 index 0000000000000000000000000000000000000000..392af492b702a53bff9dce509ce5fb31dee7c51a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lucy/pipeline_lucy_edit.py @@ -0,0 +1,733 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# Copyright 2025 The Decart AI Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modifications by Decart AI Team: +# - Based on pipeline_wan.py, but with supports receiving a condition video appended to the channel dimension. + +import html +from typing import Any, Callable + +import regex as re +import torch +from PIL import Image +from transformers import AutoTokenizer, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import WanLoraLoaderMixin +from ...models import AutoencoderKLWan, WanTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import LucyPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> from typing import list + + >>> import torch + >>> from PIL import Image + + >>> from diffusers import AutoencoderKLWan, LucyEditPipeline + >>> from diffusers.utils import export_to_video, load_video + + >>> # Arguments + >>> url = "https://d2drjpuinn46lb.cloudfront.net/painter_original_edit.mp4" + >>> prompt = "Change the apron and blouse to a classic clown costume: satin polka-dot jumpsuit in bright primary colors, ruffled white collar, oversized pom-pom buttons, white gloves, oversized red shoes, red foam nose; soft window light from left, eye-level medium shot, natural folds and fabric highlights." + >>> negative_prompt = "" + >>> num_frames = 81 + >>> height = 480 + >>> width = 832 + + + >>> # Load video + >>> def convert_video(video: list[Image.Image]) -> list[Image.Image]: + ... video = load_video(url)[:num_frames] + ... video = [video[i].resize((width, height)) for i in range(num_frames)] + ... return video + + + >>> video = load_video(url, convert_method=convert_video) + + >>> # Load model + >>> model_id = "decart-ai/Lucy-Edit-Dev" + >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + >>> pipe = LucyEditPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> # Generate video + >>> output = pipe( + ... prompt=prompt, + ... video=video, + ... negative_prompt=negative_prompt, + ... height=480, + ... width=832, + ... num_frames=81, + ... guidance_scale=5.0, + ... ).frames[0] + + >>> # Export video + >>> export_to_video(output, "output.mp4", fps=24) + ``` +""" + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class LucyEditPipeline(DiffusionPipeline, WanLoraLoaderMixin): + r""" + Pipeline for video-to-video generation using Lucy Edit. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + transformer ([`WanTransformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + transformer_2 ([`WanTransformer3DModel`], *optional*): + Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables + two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise + stages. If not provided, only `transformer` is used. + boundary_ratio (`float`, *optional*, defaults to `None`): + Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising. + The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided, + `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps < + boundary_timestep. If `None`, only `transformer` is used for the entire denoising process. + """ + + model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _optional_components = ["transformer", "transformer_2"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKLWan, + scheduler: FlowMatchEulerDiscreteScheduler, + transformer: WanTransformer3DModel | None = None, + transformer_2: WanTransformer3DModel | None = None, + boundary_ratio: float | None = None, + expand_timesteps: bool = False, # Wan2.2 ti2v + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + transformer_2=transformer_2, + ) + self.register_to_config(boundary_ratio=boundary_ratio) + self.register_to_config(expand_timesteps=expand_timesteps) + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + video, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + guidance_scale_2=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if self.config.boundary_ratio is None and guidance_scale_2 is not None: + raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.") + + if video is None: + raise ValueError("`video` is required, received None.") + + def prepare_latents( + self, + video: torch.Tensor | None = None, + batch_size: int = 1, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + num_latent_frames = ( + (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1) + ) + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + # Prepare noise latents + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # Prepare condition latents + condition_latents = [ + retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video + ] + + condition_latents = torch.cat(condition_latents, dim=0).to(dtype) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, dtype + ) + + condition_latents = (condition_latents - latents_mean) * latents_std + + # Check shapes + assert latents.shape == condition_latents.shape, ( + f"Latents shape {latents.shape} does not match expected shape {condition_latents.shape}. Please check the input." + ) + + return latents, condition_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + video: list[Image.Image], + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + guidance_scale_2: float | None = None, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + video (`list[Image.Image]`): + The video to use as the condition for the video generation. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). + height (`int`, defaults to `480`): + The height in pixels of the generated image. + width (`int`, defaults to `832`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_scale_2 (`float`, *optional*, defaults to `None`): + Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's + `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2` + and the pipeline's `boundary_ratio` are not None. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`LucyPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + + Examples: + + Returns: + [`~LucyPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`LucyPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + video, + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + guidance_scale_2, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + if self.config.boundary_ratio is not None and guidance_scale_2 is None: + guidance_scale_2 = guidance_scale + + self._guidance_scale = guidance_scale + self._guidance_scale_2 = guidance_scale_2 + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = ( + self.transformer.config.out_channels + if self.transformer is not None + else self.transformer_2.config.out_channels + ) + video = self.video_processor.preprocess_video(video, height=height, width=width).to( + device, dtype=torch.float32 + ) + latents, condition_latents = self.prepare_latents( + video, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + mask = torch.ones(latents.shape, dtype=torch.float32, device=device) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + if self.config.boundary_ratio is not None: + boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps + else: + boundary_timestep = None + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + if boundary_timestep is None or t >= boundary_timestep: + # wan2.1 or high-noise stage in wan2.2 + current_model = self.transformer + current_guidance_scale = guidance_scale + else: + # low-noise stage in wan2.2 + current_model = self.transformer_2 + current_guidance_scale = guidance_scale_2 + + # latent_model_input = latents.to(transformer_dtype) + latent_model_input = torch.cat([latents, condition_latents], dim=1).to(transformer_dtype) + # latent_model_input = torch.cat([latents, latents], dim=1).to(transformer_dtype) + if self.config.expand_timesteps: + # seq_len: num_latent_frames * latent_height//2 * latent_width//2 + temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten() + # batch_size, seq_len + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + else: + timestep = t.expand(latents.shape[0]) + + with current_model.cache_context("cond"): + noise_pred = current_model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + with current_model.cache_context("uncond"): + noise_uncond = current_model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return LucyPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lucy/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lucy/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..197ce194f475e5b738bd19cb6201fe943ad603ea --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lucy/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class LucyPipelineOutput(BaseOutput): + r""" + Output class for Lucy pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lumina/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lumina/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a19dc7e94641f59cad8588e867412829cd4fa793 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lumina/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_lumina import LuminaPipeline, LuminaText2ImgPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lumina/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lumina/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..533b95833379ed07b014348946442390687f3eea Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lumina/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lumina/__pycache__/pipeline_lumina.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lumina/__pycache__/pipeline_lumina.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bab9a271f3a9c3f41a568495c42b178bad642f74 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lumina/__pycache__/pipeline_lumina.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lumina/pipeline_lumina.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lumina/pipeline_lumina.py new file mode 100644 index 0000000000000000000000000000000000000000..cc123218f4ee5f7452c7dcb77b7de41865101c4d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lumina/pipeline_lumina.py @@ -0,0 +1,957 @@ +# Copyright 2025 Alpha-VLLM and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import math +import re +import urllib.parse as ul +from typing import Callable + +import torch +from transformers import GemmaPreTrainedModel, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL +from ...models.embeddings import get_2d_rotary_pos_embed_lumina +from ...models.transformers.lumina_nextdit2d import LuminaNextDiT2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + BACKENDS_MAPPING, + deprecate, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LuminaPipeline + + >>> pipe = LuminaPipeline.from_pretrained("Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16) + >>> # Enable memory optimizations. + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. Background shows an industrial revolution cityscape with smoky skies and tall, metal structures" + >>> image = pipe(prompt).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class LuminaPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Lumina-T2I. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`GemmaPreTrainedModel`]): + Frozen Gemma text-encoder. + tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`): + Gemma tokenizer. + transformer ([`Transformer2DModel`]): + A text conditioned `Transformer2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + ] + + def __init__( + self, + transformer: LuminaNextDiT2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: GemmaPreTrainedModel, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.max_sequence_length = 256 + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + self.default_image_size = self.default_sample_size * self.vae_scale_factor + + def _get_gemma_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + clean_caption: bool | None = False, + max_length: int | None = None, + ): + device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + pad_to_multiple_of=8, + max_length=self.max_sequence_length, + truncation=True, + padding=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device) + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because Gemma can only handle sequences up to" + f" {self.max_sequence_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = self.text_encoder( + text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + prompt_embeds = prompt_embeds.hidden_states[-2] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, prompt_attention_mask + + # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + clean_caption: bool = False, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + Lumina-T2I, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Lumina-T2I, it's should be the embeddings of the "" string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt. + """ + if device is None: + device = self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + clean_caption=clean_caption, + ) + + # Get negative embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt if negative_prompt is not None else "" + + # Normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + # Padding negative prompt to the same length with prompt + prompt_max_length = prompt_embeds.shape[1] + negative_text_inputs = self.tokenizer( + negative_prompt, + padding="max_length", + max_length=prompt_max_length, + truncation=True, + return_tensors="pt", + ) + negative_text_input_ids = negative_text_inputs.input_ids.to(device) + negative_prompt_attention_mask = negative_text_inputs.attention_mask.to(device) + # Get the negative prompt embeddings + negative_prompt_embeds = self.text_encoder( + negative_text_input_ids, + attention_mask=negative_prompt_attention_mask, + output_hidden_states=True, + ) + + negative_dtype = self.text_encoder.dtype + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + _, seq_len, _ = negative_prompt_embeds.shape + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=negative_dtype, device=device) + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + negative_prompt_attention_mask = negative_prompt_attention_mask.view( + batch_size * num_images_per_prompt, -1 + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + width: int | None = None, + height: int | None = None, + num_inference_steps: int = 30, + guidance_scale: float = 4.0, + negative_prompt: str | list[str] = None, + sigmas: list[float] = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + clean_caption: bool = True, + max_sequence_length: int = 256, + scaling_watershed: float | None = 1.0, + proportional_attn: bool | None = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ) -> ImagePipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 30): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Lumina-T2I this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + max_sequence_length (`int` defaults to 120): + Maximum sequence length to use with the `prompt`. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + + cross_attention_kwargs = {} + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if proportional_attn: + cross_attention_kwargs["base_sequence_length"] = (self.default_image_size // 16) ** 2 + + scaling_factor = math.sqrt(width * height / self.default_image_size**2) + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([prompt_attention_mask, negative_prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas + ) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + self._num_timesteps = len(timesteps) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + is_npu = latent_model_input.device.type == "npu" + if isinstance(current_timestep, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + current_timestep = torch.tensor( + [current_timestep], + dtype=dtype, + device=latent_model_input.device, + ) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image + current_timestep = 1 - current_timestep / self.scheduler.config.num_train_timesteps + + # prepare image_rotary_emb for positional encoding + # dynamic scaling_factor for different resolution. + # NOTE: For `Time-aware` denosing mechanism from Lumina-Next + # https://huggingface.co/papers/2406.18583, Sec 2.3 + # NOTE: We should compute different image_rotary_emb with different timestep. + if current_timestep[0] < scaling_watershed: + linear_factor = scaling_factor + ntk_factor = 1.0 + else: + linear_factor = 1.0 + ntk_factor = scaling_factor + image_rotary_emb = get_2d_rotary_pos_embed_lumina( + self.transformer.head_dim, + 384, + 384, + linear_factor=linear_factor, + ntk_factor=ntk_factor, + ) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=current_timestep, + encoder_hidden_states=prompt_embeds, + encoder_mask=prompt_attention_mask, + image_rotary_emb=image_rotary_emb, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred.chunk(2, dim=1)[0] + + # perform guidance scale + # NOTE: For exact reproducibility reasons, we apply classifier-free guidance on only + # three channels by default. The standard approach to cfg applies it to all channels. + # This can be done by uncommenting the following line and commenting-out the line following that. + # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] + if do_classifier_free_guidance: + noise_pred_eps, noise_pred_rest = noise_pred[:, :3], noise_pred[:, 3:] + noise_pred_cond_eps, noise_pred_uncond_eps = torch.split( + noise_pred_eps, len(noise_pred_eps) // 2, dim=0 + ) + noise_pred_half = noise_pred_uncond_eps + guidance_scale * ( + noise_pred_cond_eps - noise_pred_uncond_eps + ) + noise_pred_eps = torch.cat([noise_pred_half, noise_pred_half], dim=0) + + noise_pred = torch.cat([noise_pred_eps, noise_pred_rest], dim=1) + noise_pred, _ = noise_pred.chunk(2, dim=0) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + noise_pred = -noise_pred + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + progress_bar.update() + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + latents = latents / self.vae.config.scaling_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + else: + image = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + + +class LuminaText2ImgPipeline(LuminaPipeline): + def __init__( + self, + transformer: LuminaNextDiT2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: GemmaPreTrainedModel, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + ): + deprecation_message = "`LuminaText2ImgPipeline` has been renamed to `LuminaPipeline` and will be removed in a future version. Please use `LuminaPipeline` instead." + deprecate("diffusers.pipelines.lumina.pipeline_lumina.LuminaText2ImgPipeline", "0.34", deprecation_message) + super().__init__( + transformer=transformer, + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lumina2/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lumina2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b1d6bfeb0d5838b1ce22104dc737106b23bc59f3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lumina2/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lumina2/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lumina2/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..739699cc7297f959b597582a3cd10eb321a3c949 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lumina2/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lumina2/__pycache__/pipeline_lumina2.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lumina2/__pycache__/pipeline_lumina2.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7517dd9471eb81ce24488efb10e804c4b38ea92 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lumina2/__pycache__/pipeline_lumina2.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lumina2/pipeline_lumina2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lumina2/pipeline_lumina2.py new file mode 100644 index 0000000000000000000000000000000000000000..576d3e8d948636179285b6353ba8328cca21fb8a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -0,0 +1,818 @@ +# Copyright 2025 Alpha-VLLM and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast + +from ...image_processor import VaeImageProcessor +from ...loaders import Lumina2LoraLoaderMixin +from ...models import AutoencoderKL +from ...models.transformers.transformer_lumina2 import Lumina2Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import Lumina2Pipeline + + >>> pipe = Lumina2Pipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16) + >>> # Enable memory optimizations. + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. Background shows an industrial revolution cityscape with smoky skies and tall, metal structures" + >>> image = pipe(prompt).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class Lumina2Pipeline(DiffusionPipeline, Lumina2LoraLoaderMixin): + r""" + Pipeline for text-to-image generation using Lumina-T2I. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Gemma2PreTrainedModel`]): + Frozen Gemma2 text-encoder. + tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`): + Gemma tokenizer. + transformer ([`Transformer2DModel`]): + A text conditioned `Transformer2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + transformer: Lumina2Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: Gemma2PreTrainedModel, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 8 + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + self.default_image_size = self.default_sample_size * self.vae_scale_factor + self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts." + + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + def _get_gemma_prompt_embeds( + self, + prompt: str | list[str], + device: torch.device | None = None, + max_sequence_length: int = 256, + ) -> tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device) + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because Gemma can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = self.text_encoder( + text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + prompt_embeds = prompt_embeds.hidden_states[-2] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + return prompt_embeds, prompt_attention_mask + + # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + system_prompt: str | None = None, + max_sequence_length: int = 256, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + Lumina-T2I, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Lumina-T2I, it's should be the embeddings of the "" string. + max_sequence_length (`int`, defaults to `256`): + Maximum sequence length to use for the prompt. + """ + if device is None: + device = self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if system_prompt is None: + system_prompt = self.system_prompt + if prompt is not None: + prompt = [system_prompt + " " + p for p in prompt] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1) + + # Get negative embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt if negative_prompt is not None else "" + + # Normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + device=device, + max_sequence_length=max_sequence_length, + ) + + batch_size, seq_len, _ = negative_prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + negative_prompt_attention_mask = negative_prompt_attention_mask.view( + batch_size * num_images_per_prompt, -1 + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + width: int | None = None, + height: int | None = None, + num_inference_steps: int = 30, + guidance_scale: float = 4.0, + negative_prompt: str | list[str] = None, + sigmas: list[float] = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + system_prompt: str | None = None, + cfg_trunc_ratio: float = 1.0, + cfg_normalization: bool = True, + max_sequence_length: int = 256, + ) -> ImagePipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 30): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Lumina-T2I this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + attention_kwargs: + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + system_prompt (`str`, *optional*): + The system prompt to use for the image generation. + cfg_trunc_ratio (`float`, *optional*, defaults to `1.0`): + The ratio of the timestep interval to apply normalization-based guidance scale. + cfg_normalization (`bool`, *optional*, defaults to `True`): + Whether to apply normalization-based guidance scale. + max_sequence_length (`int`, defaults to `256`): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + system_prompt=system_prompt, + ) + + # 4. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + timestep_device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # compute whether apply classifier-free truncation on this timestep + do_classifier_free_truncation = (i + 1) / num_inference_steps > cfg_trunc_ratio + # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image + current_timestep = 1 - t / self.scheduler.config.num_train_timesteps + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latents.shape[0]) + + noise_pred_cond = self.transformer( + hidden_states=latents, + timestep=current_timestep, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + return_dict=False, + attention_kwargs=self.attention_kwargs, + )[0] + + # perform normalization-based guidance scale on a truncated timestep interval + if self.do_classifier_free_guidance and not do_classifier_free_truncation: + noise_pred_uncond = self.transformer( + hidden_states=latents, + timestep=current_timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_attention_mask, + return_dict=False, + attention_kwargs=self.attention_kwargs, + )[0] + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + # apply normalization after classifier-free guidance + if cfg_normalization: + cond_norm = torch.norm(noise_pred_cond, dim=-1, keepdim=True) + noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_pred = noise_pred * (cond_norm / noise_norm) + else: + noise_pred = noise_pred_cond + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + noise_pred = -noise_pred + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + else: + image = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + + +class Lumina2Text2ImgPipeline(Lumina2Pipeline): + def __init__( + self, + transformer: Lumina2Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: Gemma2PreTrainedModel, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + ): + deprecation_message = "`Lumina2Text2ImgPipeline` has been renamed to `Lumina2Pipeline` and will be removed in a future version. Please use `Lumina2Pipeline` instead." + deprecate("diffusers.pipelines.lumina2.pipeline_lumina2.Lumina2Text2ImgPipeline", "0.34", deprecation_message) + super().__init__( + transformer=transformer, + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..168a8276be4e24d4983cd48f75d0bd16f08abf35 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/__init__.py @@ -0,0 +1,52 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["marigold_image_processing"] = ["MarigoldImageProcessor"] + _import_structure["pipeline_marigold_depth"] = ["MarigoldDepthOutput", "MarigoldDepthPipeline"] + _import_structure["pipeline_marigold_intrinsics"] = ["MarigoldIntrinsicsOutput", "MarigoldIntrinsicsPipeline"] + _import_structure["pipeline_marigold_normals"] = ["MarigoldNormalsOutput", "MarigoldNormalsPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .marigold_image_processing import MarigoldImageProcessor + from .pipeline_marigold_depth import MarigoldDepthOutput, MarigoldDepthPipeline + from .pipeline_marigold_intrinsics import MarigoldIntrinsicsOutput, MarigoldIntrinsicsPipeline + from .pipeline_marigold_normals import MarigoldNormalsOutput, MarigoldNormalsPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..278a77b8613066a24caa3302f94b4f2b9fd46928 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/__pycache__/marigold_image_processing.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/__pycache__/marigold_image_processing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..071c1cc7ffdc547bf39c36b3a774e25145c4480b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/__pycache__/marigold_image_processing.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/__pycache__/pipeline_marigold_depth.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/__pycache__/pipeline_marigold_depth.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b166b0bbabe3c2473a5dee307a39a18d5222f387 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/__pycache__/pipeline_marigold_depth.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/__pycache__/pipeline_marigold_intrinsics.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/__pycache__/pipeline_marigold_intrinsics.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bcbbdce679b6fb569f2310cf526ad6967729b224 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/__pycache__/pipeline_marigold_intrinsics.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/__pycache__/pipeline_marigold_normals.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/__pycache__/pipeline_marigold_normals.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bd9afa51dfc477b9abf21c212bac1094bdc4135 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/__pycache__/pipeline_marigold_normals.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/marigold_image_processing.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/marigold_image_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..4a76ef9892e436f04dc184f7aeda0f106e7be1b8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/marigold_image_processing.py @@ -0,0 +1,672 @@ +# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. +# Copyright 2024-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information and citation instructions are available on the +# Marigold project website: https://marigoldcomputervision.github.io +# -------------------------------------------------------------------------- +from typing import Any + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from PIL import Image + +from ... import ConfigMixin +from ...configuration_utils import register_to_config +from ...image_processor import PipelineImageInput +from ...utils import CONFIG_NAME, logging +from ...utils.import_utils import is_matplotlib_available + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class MarigoldImageProcessor(ConfigMixin): + config_name = CONFIG_NAME + + @register_to_config + def __init__( + self, + vae_scale_factor: int = 8, + do_normalize: bool = True, + do_range_check: bool = True, + ): + super().__init__() + + @staticmethod + def expand_tensor_or_array(images: torch.Tensor | np.ndarray) -> torch.Tensor | np.ndarray: + """ + Expand a tensor or array to a specified number of images. + """ + if isinstance(images, np.ndarray): + if images.ndim == 2: # [H,W] -> [1,H,W,1] + images = images[None, ..., None] + if images.ndim == 3: # [H,W,C] -> [1,H,W,C] + images = images[None] + elif isinstance(images, torch.Tensor): + if images.ndim == 2: # [H,W] -> [1,1,H,W] + images = images[None, None] + elif images.ndim == 3: # [1,H,W] -> [1,1,H,W] + images = images[None] + else: + raise ValueError(f"Unexpected input type: {type(images)}") + return images + + @staticmethod + def pt_to_numpy(images: torch.Tensor) -> np.ndarray: + """ + Convert a PyTorch tensor to a NumPy image. + """ + images = images.cpu().permute(0, 2, 3, 1).float().numpy() + return images + + @staticmethod + def numpy_to_pt(images: np.ndarray) -> torch.Tensor: + """ + Convert a NumPy image to a PyTorch tensor. + """ + if np.issubdtype(images.dtype, np.integer) and not np.issubdtype(images.dtype, np.unsignedinteger): + raise ValueError(f"Input image dtype={images.dtype} cannot be a signed integer.") + if np.issubdtype(images.dtype, np.complexfloating): + raise ValueError(f"Input image dtype={images.dtype} cannot be complex.") + if np.issubdtype(images.dtype, bool): + raise ValueError(f"Input image dtype={images.dtype} cannot be boolean.") + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + @staticmethod + def resize_antialias( + image: torch.Tensor, size: tuple[int, int], mode: str, is_aa: bool | None = None + ) -> torch.Tensor: + if not torch.is_tensor(image): + raise ValueError(f"Invalid input type={type(image)}.") + if not torch.is_floating_point(image): + raise ValueError(f"Invalid input dtype={image.dtype}.") + if image.dim() != 4: + raise ValueError(f"Invalid input dimensions; shape={image.shape}.") + + antialias = is_aa and mode in ("bilinear", "bicubic") + image = F.interpolate(image, size, mode=mode, antialias=antialias) + + return image + + @staticmethod + def resize_to_max_edge(image: torch.Tensor, max_edge_sz: int, mode: str) -> torch.Tensor: + if not torch.is_tensor(image): + raise ValueError(f"Invalid input type={type(image)}.") + if not torch.is_floating_point(image): + raise ValueError(f"Invalid input dtype={image.dtype}.") + if image.dim() != 4: + raise ValueError(f"Invalid input dimensions; shape={image.shape}.") + + h, w = image.shape[-2:] + max_orig = max(h, w) + new_h = h * max_edge_sz // max_orig + new_w = w * max_edge_sz // max_orig + + if new_h == 0 or new_w == 0: + raise ValueError(f"Extreme aspect ratio of the input image: [{w} x {h}]") + + image = MarigoldImageProcessor.resize_antialias(image, (new_h, new_w), mode, is_aa=True) + + return image + + @staticmethod + def pad_image(image: torch.Tensor, align: int) -> tuple[torch.Tensor, tuple[int, int]]: + if not torch.is_tensor(image): + raise ValueError(f"Invalid input type={type(image)}.") + if not torch.is_floating_point(image): + raise ValueError(f"Invalid input dtype={image.dtype}.") + if image.dim() != 4: + raise ValueError(f"Invalid input dimensions; shape={image.shape}.") + + h, w = image.shape[-2:] + ph, pw = -h % align, -w % align + + image = F.pad(image, (0, pw, 0, ph), mode="replicate") + + return image, (ph, pw) + + @staticmethod + def unpad_image(image: torch.Tensor, padding: tuple[int, int]) -> torch.Tensor: + if not torch.is_tensor(image): + raise ValueError(f"Invalid input type={type(image)}.") + if not torch.is_floating_point(image): + raise ValueError(f"Invalid input dtype={image.dtype}.") + if image.dim() != 4: + raise ValueError(f"Invalid input dimensions; shape={image.shape}.") + + ph, pw = padding + uh = None if ph == 0 else -ph + uw = None if pw == 0 else -pw + + image = image[:, :, :uh, :uw] + + return image + + @staticmethod + def load_image_canonical( + image: torch.Tensor | np.ndarray | Image.Image, + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float32, + ) -> tuple[torch.Tensor, int]: + if isinstance(image, Image.Image): + image = np.array(image) + + image_dtype_max = None + if isinstance(image, (np.ndarray, torch.Tensor)): + image = MarigoldImageProcessor.expand_tensor_or_array(image) + if image.ndim != 4: + raise ValueError("Input image is not 2-, 3-, or 4-dimensional.") + if isinstance(image, np.ndarray): + if np.issubdtype(image.dtype, np.integer) and not np.issubdtype(image.dtype, np.unsignedinteger): + raise ValueError(f"Input image dtype={image.dtype} cannot be a signed integer.") + if np.issubdtype(image.dtype, np.complexfloating): + raise ValueError(f"Input image dtype={image.dtype} cannot be complex.") + if np.issubdtype(image.dtype, bool): + raise ValueError(f"Input image dtype={image.dtype} cannot be boolean.") + if np.issubdtype(image.dtype, np.unsignedinteger): + image_dtype_max = np.iinfo(image.dtype).max + image = image.astype(np.float32) # because torch does not have unsigned dtypes beyond torch.uint8 + image = MarigoldImageProcessor.numpy_to_pt(image) + + if torch.is_tensor(image) and not torch.is_floating_point(image) and image_dtype_max is None: + if image.dtype != torch.uint8: + raise ValueError(f"Image dtype={image.dtype} is not supported.") + image_dtype_max = 255 + + if not torch.is_tensor(image): + raise ValueError(f"Input type unsupported: {type(image)}.") + + if image.shape[1] == 1: + image = image.repeat(1, 3, 1, 1) # [N,1,H,W] -> [N,3,H,W] + if image.shape[1] != 3: + raise ValueError(f"Input image is not 1- or 3-channel: {image.shape}.") + + image = image.to(device=device, dtype=dtype) + + if image_dtype_max is not None: + image = image / image_dtype_max + + return image + + @staticmethod + def check_image_values_range(image: torch.Tensor) -> None: + if not torch.is_tensor(image): + raise ValueError(f"Invalid input type={type(image)}.") + if not torch.is_floating_point(image): + raise ValueError(f"Invalid input dtype={image.dtype}.") + if image.min().item() < 0.0 or image.max().item() > 1.0: + raise ValueError("Input image data is partially outside of the [0,1] range.") + + def preprocess( + self, + image: PipelineImageInput, + processing_resolution: int | None = None, + resample_method_input: str = "bilinear", + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float32, + ): + if isinstance(image, list): + images = None + for i, img in enumerate(image): + img = self.load_image_canonical(img, device, dtype) # [N,3,H,W] + if images is None: + images = img + else: + if images.shape[2:] != img.shape[2:]: + raise ValueError( + f"Input image[{i}] has incompatible dimensions {img.shape[2:]} with the previous images " + f"{images.shape[2:]}" + ) + images = torch.cat((images, img), dim=0) + image = images + del images + else: + image = self.load_image_canonical(image, device, dtype) # [N,3,H,W] + + original_resolution = image.shape[2:] + + if self.config.do_range_check: + self.check_image_values_range(image) + + if self.config.do_normalize: + image = image * 2.0 - 1.0 + + if processing_resolution is not None and processing_resolution > 0: + image = self.resize_to_max_edge(image, processing_resolution, resample_method_input) # [N,3,PH,PW] + + image, padding = self.pad_image(image, self.config.vae_scale_factor) # [N,3,PPH,PPW] + + return image, padding, original_resolution + + @staticmethod + def colormap( + image: np.ndarray | torch.Tensor, + cmap: str = "Spectral", + bytes: bool = False, + _force_method: str | None = None, + ) -> np.ndarray | torch.Tensor: + """ + Converts a monochrome image into an RGB image by applying the specified colormap. This function mimics the + behavior of matplotlib.colormaps, but allows the user to use the most discriminative color maps ("Spectral", + "binary") without having to install or import matplotlib. For all other cases, the function will attempt to use + the native implementation. + + Args: + image: 2D tensor of values between 0 and 1, either as np.ndarray or torch.Tensor. + cmap: Colormap name. + bytes: Whether to return the output as uint8 or floating point image. + _force_method: + Can be used to specify whether to use the native implementation (`"matplotlib"`), the efficient custom + implementation of the select color maps (`"custom"`), or rely on autodetection (`None`, default). + + Returns: + An RGB-colorized tensor corresponding to the input image. + """ + if not (torch.is_tensor(image) or isinstance(image, np.ndarray)): + raise ValueError("Argument must be a numpy array or torch tensor.") + if _force_method not in (None, "matplotlib", "custom"): + raise ValueError("_force_method must be either `None`, `'matplotlib'` or `'custom'`.") + + supported_cmaps = { + "binary": [ + (1.0, 1.0, 1.0), + (0.0, 0.0, 0.0), + ], + "Spectral": [ # Taken from matplotlib/_cm.py + (0.61960784313725492, 0.003921568627450980, 0.25882352941176473), # 0.0 -> [0] + (0.83529411764705885, 0.24313725490196078, 0.30980392156862746), + (0.95686274509803926, 0.42745098039215684, 0.2627450980392157), + (0.99215686274509807, 0.68235294117647061, 0.38039215686274508), + (0.99607843137254903, 0.8784313725490196, 0.54509803921568623), + (1.0, 1.0, 0.74901960784313726), + (0.90196078431372551, 0.96078431372549022, 0.59607843137254901), + (0.6705882352941176, 0.8666666666666667, 0.64313725490196083), + (0.4, 0.76078431372549016, 0.6470588235294118), + (0.19607843137254902, 0.53333333333333333, 0.74117647058823533), + (0.36862745098039218, 0.30980392156862746, 0.63529411764705879), # 1.0 -> [K-1] + ], + } + + def method_matplotlib(image, cmap, bytes=False): + if is_matplotlib_available(): + import matplotlib + else: + return None + + arg_is_pt, device = torch.is_tensor(image), None + if arg_is_pt: + image, device = image.cpu().numpy(), image.device + + if cmap not in matplotlib.colormaps: + raise ValueError( + f"Unexpected color map {cmap}; available options are: {', '.join(list(matplotlib.colormaps.keys()))}" + ) + + cmap = matplotlib.colormaps[cmap] + out = cmap(image, bytes=bytes) # [?,4] + out = out[..., :3] # [?,3] + + if arg_is_pt: + out = torch.tensor(out, device=device) + + return out + + def method_custom(image, cmap, bytes=False): + arg_is_np = isinstance(image, np.ndarray) + if arg_is_np: + image = torch.tensor(image) + if image.dtype == torch.uint8: + image = image.float() / 255 + else: + image = image.float() + + is_cmap_reversed = cmap.endswith("_r") + if is_cmap_reversed: + cmap = cmap[:-2] + + if cmap not in supported_cmaps: + raise ValueError( + f"Only {list(supported_cmaps.keys())} color maps are available without installing matplotlib." + ) + + cmap = supported_cmaps[cmap] + if is_cmap_reversed: + cmap = cmap[::-1] + cmap = torch.tensor(cmap, dtype=torch.float, device=image.device) # [K,3] + K = cmap.shape[0] + + pos = image.clamp(min=0, max=1) * (K - 1) + left = pos.long() + right = (left + 1).clamp(max=K - 1) + + d = (pos - left.float()).unsqueeze(-1) + left_colors = cmap[left] + right_colors = cmap[right] + + out = (1 - d) * left_colors + d * right_colors + + if bytes: + out = (out * 255).to(torch.uint8) + + if arg_is_np: + out = out.numpy() + + return out + + if _force_method is None and torch.is_tensor(image) and cmap == "Spectral": + return method_custom(image, cmap, bytes) + + out = None + if _force_method != "custom": + out = method_matplotlib(image, cmap, bytes) + + if _force_method == "matplotlib" and out is None: + raise ImportError("Make sure to install matplotlib if you want to use a color map other than 'Spectral'.") + + if out is None: + out = method_custom(image, cmap, bytes) + + return out + + @staticmethod + def visualize_depth( + depth: PIL.Image.Image + | np.ndarray + | torch.Tensor + | list[PIL.Image.Image] + | list[np.ndarray] + | list[torch.Tensor], + val_min: float = 0.0, + val_max: float = 1.0, + color_map: str = "Spectral", + ) -> list[PIL.Image.Image]: + """ + Visualizes depth maps, such as predictions of the `MarigoldDepthPipeline`. + + Args: + depth (`PIL.Image.Image | np.ndarray | torch.Tensor | list[PIL.Image.Image, list[np.ndarray], + list[torch.Tensor]]`): Depth maps. + val_min (`float`, *optional*, defaults to `0.0`): Minimum value of the visualized depth range. + val_max (`float`, *optional*, defaults to `1.0`): Maximum value of the visualized depth range. + color_map (`str`, *optional*, defaults to `"Spectral"`): Color map used to convert a single-channel + depth prediction into colored representation. + + Returns: `list[PIL.Image.Image]` with depth maps visualization. + """ + if val_max <= val_min: + raise ValueError(f"Invalid values range: [{val_min}, {val_max}].") + + def visualize_depth_one(img, idx=None): + prefix = "Depth" + (f"[{idx}]" if idx else "") + if isinstance(img, PIL.Image.Image): + if img.mode != "I;16": + raise ValueError(f"{prefix}: invalid PIL mode={img.mode}.") + img = np.array(img).astype(np.float32) / (2**16 - 1) + if isinstance(img, np.ndarray) or torch.is_tensor(img): + if img.ndim != 2: + raise ValueError(f"{prefix}: unexpected shape={img.shape}.") + if isinstance(img, np.ndarray): + img = torch.from_numpy(img) + if not torch.is_floating_point(img): + raise ValueError(f"{prefix}: unexpected dtype={img.dtype}.") + else: + raise ValueError(f"{prefix}: unexpected type={type(img)}.") + if val_min != 0.0 or val_max != 1.0: + img = (img - val_min) / (val_max - val_min) + img = MarigoldImageProcessor.colormap(img, cmap=color_map, bytes=True) # [H,W,3] + img = PIL.Image.fromarray(img.cpu().numpy()) + return img + + if depth is None or isinstance(depth, list) and any(o is None for o in depth): + raise ValueError("Input depth is `None`") + if isinstance(depth, (np.ndarray, torch.Tensor)): + depth = MarigoldImageProcessor.expand_tensor_or_array(depth) + if isinstance(depth, np.ndarray): + depth = MarigoldImageProcessor.numpy_to_pt(depth) # [N,H,W,1] -> [N,1,H,W] + if not (depth.ndim == 4 and depth.shape[1] == 1): # [N,1,H,W] + raise ValueError(f"Unexpected input shape={depth.shape}, expecting [N,1,H,W].") + return [visualize_depth_one(img[0], idx) for idx, img in enumerate(depth)] + elif isinstance(depth, list): + return [visualize_depth_one(img, idx) for idx, img in enumerate(depth)] + else: + raise ValueError(f"Unexpected input type: {type(depth)}") + + @staticmethod + def export_depth_to_16bit_png( + depth: np.ndarray | torch.Tensor | list[np.ndarray] | list[torch.Tensor], + val_min: float = 0.0, + val_max: float = 1.0, + ) -> list[PIL.Image.Image]: + def export_depth_to_16bit_png_one(img, idx=None): + prefix = "Depth" + (f"[{idx}]" if idx else "") + if not isinstance(img, np.ndarray) and not torch.is_tensor(img): + raise ValueError(f"{prefix}: unexpected type={type(img)}.") + if img.ndim != 2: + raise ValueError(f"{prefix}: unexpected shape={img.shape}.") + if torch.is_tensor(img): + img = img.cpu().numpy() + if not np.issubdtype(img.dtype, np.floating): + raise ValueError(f"{prefix}: unexpected dtype={img.dtype}.") + if val_min != 0.0 or val_max != 1.0: + img = (img - val_min) / (val_max - val_min) + img = (img * (2**16 - 1)).astype(np.uint16) + img = PIL.Image.fromarray(img, mode="I;16") + return img + + if depth is None or isinstance(depth, list) and any(o is None for o in depth): + raise ValueError("Input depth is `None`") + if isinstance(depth, (np.ndarray, torch.Tensor)): + depth = MarigoldImageProcessor.expand_tensor_or_array(depth) + if isinstance(depth, np.ndarray): + depth = MarigoldImageProcessor.numpy_to_pt(depth) # [N,H,W,1] -> [N,1,H,W] + if not (depth.ndim == 4 and depth.shape[1] == 1): + raise ValueError(f"Unexpected input shape={depth.shape}, expecting [N,1,H,W].") + return [export_depth_to_16bit_png_one(img[0], idx) for idx, img in enumerate(depth)] + elif isinstance(depth, list): + return [export_depth_to_16bit_png_one(img, idx) for idx, img in enumerate(depth)] + else: + raise ValueError(f"Unexpected input type: {type(depth)}") + + @staticmethod + def visualize_normals( + normals: np.ndarray | torch.Tensor | list[np.ndarray] | list[torch.Tensor], + flip_x: bool = False, + flip_y: bool = False, + flip_z: bool = False, + ) -> list[PIL.Image.Image]: + """ + Visualizes surface normals, such as predictions of the `MarigoldNormalsPipeline`. + + Args: + normals (`np.ndarray | torch.Tensor | list[np.ndarray, list[torch.Tensor]]`): + Surface normals. + flip_x (`bool`, *optional*, defaults to `False`): Flips the X axis of the normals frame of reference. + Default direction is right. + flip_y (`bool`, *optional*, defaults to `False`): Flips the Y axis of the normals frame of reference. + Default direction is top. + flip_z (`bool`, *optional*, defaults to `False`): Flips the Z axis of the normals frame of reference. + Default direction is facing the observer. + + Returns: `list[PIL.Image.Image]` with surface normals visualization. + """ + flip_vec = None + if any((flip_x, flip_y, flip_z)): + flip_vec = torch.tensor( + [ + (-1) ** flip_x, + (-1) ** flip_y, + (-1) ** flip_z, + ], + dtype=torch.float32, + ) + + def visualize_normals_one(img, idx=None): + img = img.permute(1, 2, 0) + if flip_vec is not None: + img *= flip_vec.to(img.device) + img = (img + 1.0) * 0.5 + img = (img * 255).to(dtype=torch.uint8, device="cpu").numpy() + img = PIL.Image.fromarray(img) + return img + + if normals is None or isinstance(normals, list) and any(o is None for o in normals): + raise ValueError("Input normals is `None`") + if isinstance(normals, (np.ndarray, torch.Tensor)): + normals = MarigoldImageProcessor.expand_tensor_or_array(normals) + if isinstance(normals, np.ndarray): + normals = MarigoldImageProcessor.numpy_to_pt(normals) # [N,3,H,W] + if not (normals.ndim == 4 and normals.shape[1] == 3): + raise ValueError(f"Unexpected input shape={normals.shape}, expecting [N,3,H,W].") + return [visualize_normals_one(img, idx) for idx, img in enumerate(normals)] + elif isinstance(normals, list): + return [visualize_normals_one(img, idx) for idx, img in enumerate(normals)] + else: + raise ValueError(f"Unexpected input type: {type(normals)}") + + @staticmethod + def visualize_intrinsics( + prediction: np.ndarray | torch.Tensor | list[np.ndarray] | list[torch.Tensor], + target_properties: dict[str, Any], + color_map: str | dict[str, str] = "binary", + ) -> list[dict[str, PIL.Image.Image]]: + """ + Visualizes intrinsic image decomposition, such as predictions of the `MarigoldIntrinsicsPipeline`. + + Args: + prediction (`np.ndarray | torch.Tensor | list[np.ndarray, list[torch.Tensor]]`): + Intrinsic image decomposition. + target_properties (`dict[str, Any]`): + Decomposition properties. Expected entries: `target_names: list[str]` and a dictionary with keys + `prediction_space: str`, `sub_target_names: list[str | Null]` (must have 3 entries, null for missing + modalities), `up_to_scale: bool`, one for each target and sub-target. + color_map (`str | dict[str, str]`, *optional*, defaults to `"Spectral"`): + Color map used to convert a single-channel predictions into colored representations. When a dictionary + is passed, each modality can be colored with its own color map. + + Returns: `list[dict[str, PIL.Image.Image]]` with intrinsic image decomposition visualization. + """ + if "target_names" not in target_properties: + raise ValueError("Missing `target_names` in target_properties") + if not isinstance(color_map, str) and not ( + isinstance(color_map, dict) + and all(isinstance(k, str) and isinstance(v, str) for k, v in color_map.items()) + ): + raise ValueError("`color_map` must be a string or a dictionary of strings") + n_targets = len(target_properties["target_names"]) + + def visualize_targets_one(images, idx=None): + # img: [T, 3, H, W] + out = {} + for target_name, img in zip(target_properties["target_names"], images): + img = img.permute(1, 2, 0) # [H, W, 3] + prediction_space = target_properties[target_name].get("prediction_space", "srgb") + if prediction_space == "stack": + sub_target_names = target_properties[target_name]["sub_target_names"] + if len(sub_target_names) != 3 or any( + not (isinstance(s, str) or s is None) for s in sub_target_names + ): + raise ValueError(f"Unexpected target sub-names {sub_target_names} in {target_name}") + for i, sub_target_name in enumerate(sub_target_names): + if sub_target_name is None: + continue + sub_img = img[:, :, i] + sub_prediction_space = target_properties[sub_target_name].get("prediction_space", "srgb") + if sub_prediction_space == "linear": + sub_up_to_scale = target_properties[sub_target_name].get("up_to_scale", False) + if sub_up_to_scale: + sub_img = sub_img / max(sub_img.max().item(), 1e-6) + sub_img = sub_img ** (1 / 2.2) + cmap_name = ( + color_map if isinstance(color_map, str) else color_map.get(sub_target_name, "binary") + ) + sub_img = MarigoldImageProcessor.colormap(sub_img, cmap=cmap_name, bytes=True) + sub_img = PIL.Image.fromarray(sub_img.cpu().numpy()) + out[sub_target_name] = sub_img + elif prediction_space == "linear": + up_to_scale = target_properties[target_name].get("up_to_scale", False) + if up_to_scale: + img = img / max(img.max().item(), 1e-6) + img = img ** (1 / 2.2) + elif prediction_space == "srgb": + pass + img = (img * 255).to(dtype=torch.uint8, device="cpu").numpy() + img = PIL.Image.fromarray(img) + out[target_name] = img + return out + + if prediction is None or isinstance(prediction, list) and any(o is None for o in prediction): + raise ValueError("Input prediction is `None`") + if isinstance(prediction, (np.ndarray, torch.Tensor)): + prediction = MarigoldImageProcessor.expand_tensor_or_array(prediction) + if isinstance(prediction, np.ndarray): + prediction = MarigoldImageProcessor.numpy_to_pt(prediction) # [N*T,3,H,W] + if not (prediction.ndim == 4 and prediction.shape[1] == 3 and prediction.shape[0] % n_targets == 0): + raise ValueError(f"Unexpected input shape={prediction.shape}, expecting [N*T,3,H,W].") + N_T, _, H, W = prediction.shape + N = N_T // n_targets + prediction = prediction.reshape(N, n_targets, 3, H, W) + return [visualize_targets_one(img, idx) for idx, img in enumerate(prediction)] + elif isinstance(prediction, list): + return [visualize_targets_one(img, idx) for idx, img in enumerate(prediction)] + else: + raise ValueError(f"Unexpected input type: {type(prediction)}") + + @staticmethod + def visualize_uncertainty( + uncertainty: np.ndarray | torch.Tensor | list[np.ndarray] | list[torch.Tensor], + saturation_percentile=95, + ) -> list[PIL.Image.Image]: + """ + Visualizes dense uncertainties, such as produced by `MarigoldDepthPipeline`, `MarigoldNormalsPipeline`, or + `MarigoldIntrinsicsPipeline`. + + Args: + uncertainty (`np.ndarray | torch.Tensor | list[np.ndarray, list[torch.Tensor]]`): + Uncertainty maps. + saturation_percentile (`int`, *optional*, defaults to `95`): + Specifies the percentile uncertainty value visualized with maximum intensity. + + Returns: `list[PIL.Image.Image]` with uncertainty visualization. + """ + + def visualize_uncertainty_one(img, idx=None): + prefix = "Uncertainty" + (f"[{idx}]" if idx else "") + if img.min() < 0: + raise ValueError(f"{prefix}: unexpected data range, min={img.min()}.") + img = img.permute(1, 2, 0) # [H,W,C] + img = img.squeeze(2).cpu().numpy() # [H,W] or [H,W,3] + saturation_value = np.percentile(img, saturation_percentile) + img = np.clip(img * 255 / saturation_value, 0, 255) + img = img.astype(np.uint8) + img = PIL.Image.fromarray(img) + return img + + if uncertainty is None or isinstance(uncertainty, list) and any(o is None for o in uncertainty): + raise ValueError("Input uncertainty is `None`") + if isinstance(uncertainty, (np.ndarray, torch.Tensor)): + uncertainty = MarigoldImageProcessor.expand_tensor_or_array(uncertainty) + if isinstance(uncertainty, np.ndarray): + uncertainty = MarigoldImageProcessor.numpy_to_pt(uncertainty) # [N,C,H,W] + if not (uncertainty.ndim == 4 and uncertainty.shape[1] in (1, 3)): + raise ValueError(f"Unexpected input shape={uncertainty.shape}, expecting [N,C,H,W] with C in (1,3).") + return [visualize_uncertainty_one(img, idx) for idx, img in enumerate(uncertainty)] + elif isinstance(uncertainty, list): + return [visualize_uncertainty_one(img, idx) for idx, img in enumerate(uncertainty)] + else: + raise ValueError(f"Unexpected input type: {type(uncertainty)}") diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/pipeline_marigold_depth.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/pipeline_marigold_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..a81d1c51742cb1be5cd9a0f84940a86d8830ebe9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/pipeline_marigold_depth.py @@ -0,0 +1,827 @@ +# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. +# Copyright 2024-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information and citation instructions are available on the +# Marigold project website: https://marigoldcomputervision.github.io +# -------------------------------------------------------------------------- +from dataclasses import dataclass +from functools import partial +from typing import Any + +import numpy as np +import torch +from PIL import Image +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +from ...image_processor import PipelineImageInput +from ...models import ( + AutoencoderKL, + UNet2DConditionModel, +) +from ...schedulers import ( + DDIMScheduler, + LCMScheduler, +) +from ...utils import ( + BaseOutput, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.import_utils import is_scipy_available +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .marigold_image_processing import MarigoldImageProcessor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ +Examples: +```py +>>> import diffusers +>>> import torch + +>>> pipe = diffusers.MarigoldDepthPipeline.from_pretrained( +... "prs-eth/marigold-depth-v1-1", variant="fp16", torch_dtype=torch.float16 +... ).to("cuda") + +>>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") +>>> depth = pipe(image) + +>>> vis = pipe.image_processor.visualize_depth(depth.prediction) +>>> vis[0].save("einstein_depth.png") + +>>> depth_16bit = pipe.image_processor.export_depth_to_16bit_png(depth.prediction) +>>> depth_16bit[0].save("einstein_depth_16bit.png") +``` +""" + + +@dataclass +class MarigoldDepthOutput(BaseOutput): + """ + Output class for Marigold monocular depth prediction pipeline. + + Args: + prediction (`np.ndarray`, `torch.Tensor`): + Predicted depth maps with values in the range [0, 1]. The shape is `numimages × 1 × height × width` for + `torch.Tensor` or `numimages × height × width × 1` for `np.ndarray`. + uncertainty (`None`, `np.ndarray`, `torch.Tensor`): + Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is `numimages × 1 × + height × width` for `torch.Tensor` or `numimages × height × width × 1` for `np.ndarray`. + latent (`None`, `torch.Tensor`): + Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline. + The shape is `numimages * numensemble × 4 × latentheight × latentwidth`. + """ + + prediction: np.ndarray | torch.Tensor + uncertainty: None | np.ndarray | torch.Tensor + latent: None | torch.Tensor + + +class MarigoldDepthPipeline(DiffusionPipeline): + """ + Pipeline for monocular depth estimation using the Marigold method: https://marigoldmonodepth.github.io. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + unet (`UNet2DConditionModel`): + Conditional U-Net to denoise the depth latent, conditioned on image latent. + vae (`AutoencoderKL`): + Variational Auto-Encoder (VAE) Model to encode and decode images and predictions to and from latent + representations. + scheduler (`DDIMScheduler` or `LCMScheduler`): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + text_encoder (`CLIPTextModel`): + Text-encoder, for empty text embedding. + tokenizer (`CLIPTokenizer`): + CLIP tokenizer. + prediction_type (`str`, *optional*): + Type of predictions made by the model. + scale_invariant (`bool`, *optional*): + A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in + the model config. When used together with the `shift_invariant=True` flag, the model is also called + "affine-invariant". NB: overriding this value is not supported. + shift_invariant (`bool`, *optional*): + A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in + the model config. When used together with the `scale_invariant=True` flag, the model is also called + "affine-invariant". NB: overriding this value is not supported. + default_denoising_steps (`int`, *optional*): + The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable + quality with the given model. This value must be set in the model config. When the pipeline is called + without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure + reasonable results with various model flavors compatible with the pipeline, such as those relying on very + short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`). + default_processing_resolution (`int`, *optional*): + The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in + the model config. When the pipeline is called without explicitly setting `processing_resolution`, the + default value is used. This is required to ensure reasonable results with various model flavors trained + with varying optimal processing resolution values. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + supported_prediction_types = ("depth", "disparity") + + def __init__( + self, + unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: DDIMScheduler | LCMScheduler, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + prediction_type: str | None = None, + scale_invariant: bool | None = True, + shift_invariant: bool | None = True, + default_denoising_steps: int | None = None, + default_processing_resolution: int | None = None, + ): + super().__init__() + + if prediction_type not in self.supported_prediction_types: + logger.warning( + f"Potentially unsupported `prediction_type='{prediction_type}'`; values supported by the pipeline: " + f"{self.supported_prediction_types}." + ) + + self.register_modules( + unet=unet, + vae=vae, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + ) + self.register_to_config( + prediction_type=prediction_type, + scale_invariant=scale_invariant, + shift_invariant=shift_invariant, + default_denoising_steps=default_denoising_steps, + default_processing_resolution=default_processing_resolution, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + + self.scale_invariant = scale_invariant + self.shift_invariant = shift_invariant + self.default_denoising_steps = default_denoising_steps + self.default_processing_resolution = default_processing_resolution + + self.empty_text_embedding = None + + self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def check_inputs( + self, + image: PipelineImageInput, + num_inference_steps: int, + ensemble_size: int, + processing_resolution: int, + resample_method_input: str, + resample_method_output: str, + batch_size: int, + ensembling_kwargs: dict[str, Any] | None, + latents: torch.Tensor | None, + generator: torch.Generator | list[torch.Generator] | None, + output_type: str, + output_uncertainty: bool, + ) -> int: + actual_vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + if actual_vae_scale_factor != self.vae_scale_factor: + raise ValueError( + f"`vae_scale_factor` computed at initialization ({self.vae_scale_factor}) differs from the actual one ({actual_vae_scale_factor})." + ) + if num_inference_steps is None: + raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.") + if num_inference_steps < 1: + raise ValueError("`num_inference_steps` must be positive.") + if ensemble_size < 1: + raise ValueError("`ensemble_size` must be positive.") + if ensemble_size == 2: + logger.warning( + "`ensemble_size` == 2 results are similar to no ensembling (1); " + "consider increasing the value to at least 3." + ) + if ensemble_size > 1 and (self.scale_invariant or self.shift_invariant) and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use ensembling.") + if ensemble_size == 1 and output_uncertainty: + raise ValueError( + "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` " + "greater than 1." + ) + if processing_resolution is None: + raise ValueError( + "`processing_resolution` is not specified and could not be resolved from the model config." + ) + if processing_resolution < 0: + raise ValueError( + "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for " + "downsampled processing." + ) + if processing_resolution % self.vae_scale_factor != 0: + raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.") + if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): + raise ValueError( + "`resample_method_input` takes string values compatible with PIL library: " + "nearest, nearest-exact, bilinear, bicubic, area." + ) + if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): + raise ValueError( + "`resample_method_output` takes string values compatible with PIL library: " + "nearest, nearest-exact, bilinear, bicubic, area." + ) + if batch_size < 1: + raise ValueError("`batch_size` must be positive.") + if output_type not in ["pt", "np"]: + raise ValueError("`output_type` must be one of `pt` or `np`.") + if latents is not None and generator is not None: + raise ValueError("`latents` and `generator` cannot be used together.") + if ensembling_kwargs is not None: + if not isinstance(ensembling_kwargs, dict): + raise ValueError("`ensembling_kwargs` must be a dictionary.") + if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("mean", "median"): + raise ValueError("`ensembling_kwargs['reduction']` can be either `'mean'` or `'median'`.") + + # image checks + num_images = 0 + W, H = None, None + if not isinstance(image, list): + image = [image] + for i, img in enumerate(image): + if isinstance(img, np.ndarray) or torch.is_tensor(img): + if img.ndim not in (2, 3, 4): + raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.") + H_i, W_i = img.shape[-2:] + N_i = 1 + if img.ndim == 4: + N_i = img.shape[0] + elif isinstance(img, Image.Image): + W_i, H_i = img.size + N_i = 1 + else: + raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.") + if W is None: + W, H = W_i, H_i + elif (W, H) != (W_i, H_i): + raise ValueError( + f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}" + ) + num_images += N_i + + # latents checks + if latents is not None: + if not torch.is_tensor(latents): + raise ValueError("`latents` must be a torch.Tensor.") + if latents.dim() != 4: + raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.") + + if processing_resolution > 0: + max_orig = max(H, W) + new_H = H * processing_resolution // max_orig + new_W = W * processing_resolution // max_orig + if new_H == 0 or new_W == 0: + raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]") + W, H = new_W, new_H + w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor + h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor + shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w) + + if latents.shape != shape_expected: + raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.") + + # generator checks + if generator is not None: + if isinstance(generator, list): + if len(generator) != num_images * ensemble_size: + raise ValueError( + "The number of generators must match the total number of ensemble members for all input images." + ) + if not all(g.device.type == generator[0].device.type for g in generator): + raise ValueError("`generator` device placement is not consistent in the list.") + elif not isinstance(generator, torch.Generator): + raise ValueError(f"Unsupported generator type: {type(generator)}.") + + return num_images + + @torch.compiler.disable + def progress_bar(self, iterable=None, total=None, desc=None, leave=True): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + progress_bar_config = dict(**self._progress_bar_config) + progress_bar_config["desc"] = progress_bar_config.get("desc", desc) + progress_bar_config["leave"] = progress_bar_config.get("leave", leave) + if iterable is not None: + return tqdm(iterable, **progress_bar_config) + elif total is not None: + return tqdm(total=total, **progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + num_inference_steps: int | None = None, + ensemble_size: int = 1, + processing_resolution: int | None = None, + match_input_resolution: bool = True, + resample_method_input: str = "bilinear", + resample_method_output: str = "bilinear", + batch_size: int = 1, + ensembling_kwargs: dict[str, Any] | None = None, + latents: torch.Tensor | list[torch.Tensor] | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + output_type: str = "np", + output_uncertainty: bool = False, + output_latent: bool = False, + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline. + + Args: + image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`), + `list[torch.Tensor]`: An input image or images used as an input for the depth estimation task. For + arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible + by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or + three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the + same width and height. + num_inference_steps (`int`, *optional*, defaults to `None`): + Number of denoising diffusion steps during inference. The default value `None` results in automatic + selection. + ensemble_size (`int`, defaults to `1`): + Number of ensemble predictions. Higher values result in measurable improvements and visual degradation. + processing_resolution (`int`, *optional*, defaults to `None`): + Effective processing resolution. When set to `0`, matches the larger input image dimension. This + produces crisper predictions, but may also lead to the overall loss of global context. The default + value `None` resolves to the optimal value from the model config. + match_input_resolution (`bool`, *optional*, defaults to `True`): + When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer + side of the output will equal to `processing_resolution`. + resample_method_input (`str`, *optional*, defaults to `"bilinear"`): + Resampling method used to resize input images to `processing_resolution`. The accepted values are: + `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. + resample_method_output (`str`, *optional*, defaults to `"bilinear"`): + Resampling method used to resize output predictions to match the input resolution. The accepted values + are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. + batch_size (`int`, *optional*, defaults to `1`): + Batch size; only matters when setting `ensemble_size` or passing a tensor of images. + ensembling_kwargs (`dict`, *optional*, defaults to `None`) + Extra dictionary with arguments for precise ensembling control. The following options are available: + - reduction (`str`, *optional*, defaults to `"median"`): Defines the ensembling function applied in + every pixel location, can be either `"median"` or `"mean"`. + - regularizer_strength (`float`, *optional*, defaults to `0.02`): Strength of the regularizer that + pulls the aligned predictions to the unit range from 0 to 1. + - max_iter (`int`, *optional*, defaults to `2`): Maximum number of the alignment solver steps. Refer to + `scipy.optimize.minimize` function, `options` argument. + - tol (`float`, *optional*, defaults to `1e-3`): Alignment solver tolerance. The solver stops when the + tolerance is reached. + - max_res (`int`, *optional*, defaults to `None`): Resolution at which the alignment is performed; + `None` matches the `processing_resolution`. + latents (`torch.Tensor`, or `list[torch.Tensor]`, *optional*, defaults to `None`): + Latent noise tensors to replace the random initialization. These can be taken from the previous + function call's output. + generator (`torch.Generator`, or `list[torch.Generator]`, *optional*, defaults to `None`): + Random number generator object to ensure reproducibility. + output_type (`str`, *optional*, defaults to `"np"`): + Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted + values are: `"np"` (numpy array) or `"pt"` (torch tensor). + output_uncertainty (`bool`, *optional*, defaults to `False`): + When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that + the `ensemble_size` argument is set to a value above 2. + output_latent (`bool`, *optional*, defaults to `False`): + When enabled, the output's `latent` field contains the latent codes corresponding to the predictions + within the ensemble. These codes can be saved, modified, and used for subsequent calls with the + `latents` argument. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.marigold.MarigoldDepthOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.marigold.MarigoldDepthOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.marigold.MarigoldDepthOutput`] is returned, otherwise a + `tuple` is returned where the first element is the prediction, the second element is the uncertainty + (or `None`), and the third is the latent (or `None`). + """ + + # 0. Resolving variables. + device = self._execution_device + dtype = self.dtype + + # Model-specific optimal default values leading to fast and reasonable results. + if num_inference_steps is None: + num_inference_steps = self.default_denoising_steps + if processing_resolution is None: + processing_resolution = self.default_processing_resolution + + # 1. Check inputs. + num_images = self.check_inputs( + image, + num_inference_steps, + ensemble_size, + processing_resolution, + resample_method_input, + resample_method_output, + batch_size, + ensembling_kwargs, + latents, + generator, + output_type, + output_uncertainty, + ) + + # 2. Prepare empty text conditioning. + # Model invocation: self.tokenizer, self.text_encoder. + if self.empty_text_embedding is None: + prompt = "" + text_inputs = self.tokenizer( + prompt, + padding="do_not_pad", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024] + + # 3. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`, + # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where + # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are + # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None` + # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of + # operation and leads to the most reasonable results. Using the native image resolution or any other processing + # resolution can lead to loss of either fine details or global context in the output predictions. + image, padding, original_resolution = self.image_processor.preprocess( + image, processing_resolution, resample_method_input, device, dtype + ) # [N,3,PPH,PPW] + + # 4. Encode input image into latent space. At this step, each of the `N` input images is represented with `E` + # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently. + # Latents of each such predictions across all input images and all ensemble members are represented in the + # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded + # into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure + # reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline + # code. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space + # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`. + # Model invocation: self.vae.encoder. + image_latent, pred_latent = self.prepare_latents( + image, latents, generator, ensemble_size, batch_size + ) # [N*E,4,h,w], [N*E,4,h,w] + + del image + + batch_empty_text_embedding = self.empty_text_embedding.to(device=device, dtype=dtype).repeat( + batch_size, 1, 1 + ) # [B,1024,2] + + # 5. Process the denoising loop. All `N * E` latents are processed sequentially in batches of size `batch_size`. + # The unet model takes concatenated latent spaces of the input image and the predicted modality as an input, and + # outputs noise for the predicted modality's latent space. The number of denoising diffusion steps is defined by + # `num_inference_steps`. It is either set directly, or resolves to the optimal value specific to the loaded + # model. + # Model invocation: self.unet. + pred_latents = [] + + for i in self.progress_bar( + range(0, num_images * ensemble_size, batch_size), leave=True, desc="Marigold predictions..." + ): + batch_image_latent = image_latent[i : i + batch_size] # [B,4,h,w] + batch_pred_latent = pred_latent[i : i + batch_size] # [B,4,h,w] + effective_batch_size = batch_image_latent.shape[0] + text = batch_empty_text_embedding[:effective_batch_size] # [B,2,1024] + + self.scheduler.set_timesteps(num_inference_steps, device=device) + for t in self.progress_bar(self.scheduler.timesteps, leave=False, desc="Diffusion steps..."): + batch_latent = torch.cat([batch_image_latent, batch_pred_latent], dim=1) # [B,8,h,w] + noise = self.unet(batch_latent, t, encoder_hidden_states=text, return_dict=False)[0] # [B,4,h,w] + batch_pred_latent = self.scheduler.step( + noise, t, batch_pred_latent, generator=generator + ).prev_sample # [B,4,h,w] + + if XLA_AVAILABLE: + xm.mark_step() + + pred_latents.append(batch_pred_latent) + + pred_latent = torch.cat(pred_latents, dim=0) # [N*E,4,h,w] + + del ( + pred_latents, + image_latent, + batch_empty_text_embedding, + batch_image_latent, + batch_pred_latent, + text, + batch_latent, + noise, + ) + + # 6. Decode predictions from latent into pixel space. The resulting `N * E` predictions have shape `(PPH, PPW)`, + # which requires slight postprocessing. Decoding into pixel space happens in batches of size `batch_size`. + # Model invocation: self.vae.decoder. + prediction = torch.cat( + [ + self.decode_prediction(pred_latent[i : i + batch_size]) + for i in range(0, pred_latent.shape[0], batch_size) + ], + dim=0, + ) # [N*E,1,PPH,PPW] + + if not output_latent: + pred_latent = None + + # 7. Remove padding. The output shape is (PH, PW). + prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,1,PH,PW] + + # 8. Ensemble and compute uncertainty (when `output_uncertainty` is set). This code treats each of the `N` + # groups of `E` ensemble predictions independently. For each group it computes an ensembled prediction of shape + # `(PH, PW)` and an optional uncertainty map of the same dimensions. After computing this pair of outputs for + # each group independently, it stacks them respectively into batches of `N` almost final predictions and + # uncertainty maps. + uncertainty = None + if ensemble_size > 1: + prediction = prediction.reshape(num_images, ensemble_size, *prediction.shape[1:]) # [N,E,1,PH,PW] + prediction = [ + self.ensemble_depth( + prediction[i], + self.scale_invariant, + self.shift_invariant, + output_uncertainty, + **(ensembling_kwargs or {}), + ) + for i in range(num_images) + ] # [ [[1,1,PH,PW], [1,1,PH,PW]], ... ] + prediction, uncertainty = zip(*prediction) # [[1,1,PH,PW], ... ], [[1,1,PH,PW], ... ] + prediction = torch.cat(prediction, dim=0) # [N,1,PH,PW] + if output_uncertainty: + uncertainty = torch.cat(uncertainty, dim=0) # [N,1,PH,PW] + else: + uncertainty = None + + # 9. If `match_input_resolution` is set, the output prediction and the uncertainty are upsampled to match the + # input resolution `(H, W)`. This step may introduce upsampling artifacts, and therefore can be disabled. + # Depending on the downstream use-case, upsampling can be also chosen based on the tolerated artifacts by + # setting the `resample_method_output` parameter (e.g., to `"nearest"`). + if match_input_resolution: + prediction = self.image_processor.resize_antialias( + prediction, original_resolution, resample_method_output, is_aa=False + ) # [N,1,H,W] + if uncertainty is not None and output_uncertainty: + uncertainty = self.image_processor.resize_antialias( + uncertainty, original_resolution, resample_method_output, is_aa=False + ) # [N,1,H,W] + + # 10. Prepare the final outputs. + if output_type == "np": + prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,1] + if uncertainty is not None and output_uncertainty: + uncertainty = self.image_processor.pt_to_numpy(uncertainty) # [N,H,W,1] + + # 11. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (prediction, uncertainty, pred_latent) + + return MarigoldDepthOutput( + prediction=prediction, + uncertainty=uncertainty, + latent=pred_latent, + ) + + def prepare_latents( + self, + image: torch.Tensor, + latents: torch.Tensor | None, + generator: torch.Generator | None, + ensemble_size: int, + batch_size: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + def retrieve_latents(encoder_output): + if hasattr(encoder_output, "latent_dist"): + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + image_latent = torch.cat( + [ + retrieve_latents(self.vae.encode(image[i : i + batch_size])) + for i in range(0, image.shape[0], batch_size) + ], + dim=0, + ) # [N,4,h,w] + image_latent = image_latent * self.vae.config.scaling_factor + image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w] + + pred_latent = latents + if pred_latent is None: + pred_latent = randn_tensor( + image_latent.shape, + generator=generator, + device=image_latent.device, + dtype=image_latent.dtype, + ) # [N*E,4,h,w] + + return image_latent, pred_latent + + def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor: + if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels: + raise ValueError( + f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}." + ) + + prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W] + + prediction = prediction.mean(dim=1, keepdim=True) # [B,1,H,W] + prediction = torch.clip(prediction, -1.0, 1.0) # [B,1,H,W] + prediction = (prediction + 1.0) / 2.0 + + return prediction # [B,1,H,W] + + @staticmethod + def ensemble_depth( + depth: torch.Tensor, + scale_invariant: bool = True, + shift_invariant: bool = True, + output_uncertainty: bool = False, + reduction: str = "median", + regularizer_strength: float = 0.02, + max_iter: int = 2, + tol: float = 1e-3, + max_res: int = 1024, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Ensembles the depth maps represented by the `depth` tensor with expected shape `(B, 1, H, W)`, where B is the + number of ensemble members for a given prediction of size `(H x W)`. Even though the function is designed for + depth maps, it can also be used with disparity maps as long as the input tensor values are non-negative. The + alignment happens when the predictions have one or more degrees of freedom, that is when they are either + affine-invariant (`scale_invariant=True` and `shift_invariant=True`), or just scale-invariant (only + `scale_invariant=True`). For absolute predictions (`scale_invariant=False` and `shift_invariant=False`) + alignment is skipped and only ensembling is performed. + + Args: + depth (`torch.Tensor`): + Input ensemble depth maps. + scale_invariant (`bool`, *optional*, defaults to `True`): + Whether to treat predictions as scale-invariant. + shift_invariant (`bool`, *optional*, defaults to `True`): + Whether to treat predictions as shift-invariant. + output_uncertainty (`bool`, *optional*, defaults to `False`): + Whether to output uncertainty map. + reduction (`str`, *optional*, defaults to `"median"`): + Reduction method used to ensemble aligned predictions. The accepted values are: `"mean"` and + `"median"`. + regularizer_strength (`float`, *optional*, defaults to `0.02`): + Strength of the regularizer that pulls the aligned predictions to the unit range from 0 to 1. + max_iter (`int`, *optional*, defaults to `2`): + Maximum number of the alignment solver steps. Refer to `scipy.optimize.minimize` function, `options` + argument. + tol (`float`, *optional*, defaults to `1e-3`): + Alignment solver tolerance. The solver stops when the tolerance is reached. + max_res (`int`, *optional*, defaults to `1024`): + Resolution at which the alignment is performed; `None` matches the `processing_resolution`. + Returns: + A tensor of aligned and ensembled depth maps and optionally a tensor of uncertainties of the same shape: + `(1, 1, H, W)`. + """ + if depth.dim() != 4 or depth.shape[1] != 1: + raise ValueError(f"Expecting 4D tensor of shape [B,1,H,W]; got {depth.shape}.") + if reduction not in ("mean", "median"): + raise ValueError(f"Unrecognized reduction method: {reduction}.") + if not scale_invariant and shift_invariant: + raise ValueError("Pure shift-invariant ensembling is not supported.") + + def init_param(depth: torch.Tensor): + init_min = depth.reshape(ensemble_size, -1).min(dim=1).values + init_max = depth.reshape(ensemble_size, -1).max(dim=1).values + + if scale_invariant and shift_invariant: + init_s = 1.0 / (init_max - init_min).clamp(min=1e-6) + init_t = -init_s * init_min + param = torch.cat((init_s, init_t)).cpu().numpy() + elif scale_invariant: + init_s = 1.0 / init_max.clamp(min=1e-6) + param = init_s.cpu().numpy() + else: + raise ValueError("Unrecognized alignment.") + param = param.astype(np.float64) + + return param + + def align(depth: torch.Tensor, param: np.ndarray) -> torch.Tensor: + if scale_invariant and shift_invariant: + s, t = np.split(param, 2) + s = torch.from_numpy(s).to(depth).view(ensemble_size, 1, 1, 1) + t = torch.from_numpy(t).to(depth).view(ensemble_size, 1, 1, 1) + out = depth * s + t + elif scale_invariant: + s = torch.from_numpy(param).to(depth).view(ensemble_size, 1, 1, 1) + out = depth * s + else: + raise ValueError("Unrecognized alignment.") + return out + + def ensemble( + depth_aligned: torch.Tensor, return_uncertainty: bool = False + ) -> tuple[torch.Tensor, torch.Tensor | None]: + uncertainty = None + if reduction == "mean": + prediction = torch.mean(depth_aligned, dim=0, keepdim=True) + if return_uncertainty: + uncertainty = torch.std(depth_aligned, dim=0, keepdim=True) + elif reduction == "median": + prediction = torch.median(depth_aligned, dim=0, keepdim=True).values + if return_uncertainty: + uncertainty = torch.median(torch.abs(depth_aligned - prediction), dim=0, keepdim=True).values + else: + raise ValueError(f"Unrecognized reduction method: {reduction}.") + return prediction, uncertainty + + def cost_fn(param: np.ndarray, depth: torch.Tensor) -> float: + cost = 0.0 + depth_aligned = align(depth, param) + + for i, j in torch.combinations(torch.arange(ensemble_size)): + diff = depth_aligned[i] - depth_aligned[j] + cost += (diff**2).mean().sqrt().item() + + if regularizer_strength > 0: + prediction, _ = ensemble(depth_aligned, return_uncertainty=False) + err_near = prediction.min().abs().item() + err_far = (1.0 - prediction.max()).abs().item() + cost += (err_near + err_far) * regularizer_strength + + return cost + + def compute_param(depth: torch.Tensor): + import scipy + + depth_to_align = depth.to(torch.float32) + if max_res is not None and max(depth_to_align.shape[2:]) > max_res: + depth_to_align = MarigoldImageProcessor.resize_to_max_edge(depth_to_align, max_res, "nearest-exact") + + param = init_param(depth_to_align) + + res = scipy.optimize.minimize( + partial(cost_fn, depth=depth_to_align), + param, + method="BFGS", + tol=tol, + options={"maxiter": max_iter, "disp": False}, + ) + + return res.x + + requires_aligning = scale_invariant or shift_invariant + ensemble_size = depth.shape[0] + + if requires_aligning: + param = compute_param(depth) + depth = align(depth, param) + + depth, uncertainty = ensemble(depth, return_uncertainty=output_uncertainty) + + depth_max = depth.max() + if scale_invariant and shift_invariant: + depth_min = depth.min() + elif scale_invariant: + depth_min = 0 + else: + raise ValueError("Unrecognized alignment.") + depth_range = (depth_max - depth_min).clamp(min=1e-6) + depth = (depth - depth_min) / depth_range + if output_uncertainty: + uncertainty /= depth_range + + return depth, uncertainty # [1,1,H,W], [1,1,H,W] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py new file mode 100644 index 0000000000000000000000000000000000000000..9488d8f5c9b81d3a3a0413e2cb7f094189c3f9cf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py @@ -0,0 +1,721 @@ +# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. +# Copyright 2024-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information and citation instructions are available on the +# Marigold project website: https://marigoldcomputervision.github.io +# -------------------------------------------------------------------------- +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch +from PIL import Image +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +from ...image_processor import PipelineImageInput +from ...models import ( + AutoencoderKL, + UNet2DConditionModel, +) +from ...schedulers import ( + DDIMScheduler, + LCMScheduler, +) +from ...utils import ( + BaseOutput, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .marigold_image_processing import MarigoldImageProcessor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ +Examples: +```py +>>> import diffusers +>>> import torch + +>>> pipe = diffusers.MarigoldIntrinsicsPipeline.from_pretrained( +... "prs-eth/marigold-iid-appearance-v1-1", variant="fp16", torch_dtype=torch.float16 +... ).to("cuda") + +>>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") +>>> intrinsics = pipe(image) + +>>> vis = pipe.image_processor.visualize_intrinsics(intrinsics.prediction, pipe.target_properties) +>>> vis[0]["albedo"].save("einstein_albedo.png") +>>> vis[0]["roughness"].save("einstein_roughness.png") +>>> vis[0]["metallicity"].save("einstein_metallicity.png") +``` +```py +>>> import diffusers +>>> import torch + +>>> pipe = diffusers.MarigoldIntrinsicsPipeline.from_pretrained( +... "prs-eth/marigold-iid-lighting-v1-1", variant="fp16", torch_dtype=torch.float16 +... ).to("cuda") + +>>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") +>>> intrinsics = pipe(image) + +>>> vis = pipe.image_processor.visualize_intrinsics(intrinsics.prediction, pipe.target_properties) +>>> vis[0]["albedo"].save("einstein_albedo.png") +>>> vis[0]["shading"].save("einstein_shading.png") +>>> vis[0]["residual"].save("einstein_residual.png") +``` +""" + + +@dataclass +class MarigoldIntrinsicsOutput(BaseOutput): + """ + Output class for Marigold Intrinsic Image Decomposition pipeline. + + Args: + prediction (`np.ndarray`, `torch.Tensor`): + Predicted image intrinsics with values in the range [0, 1]. The shape is `(numimages * numtargets) × 3 × + height × width` for `torch.Tensor` or `(numimages * numtargets) × height × width × 3` for `np.ndarray`, + where `numtargets` corresponds to the number of predicted target modalities of the intrinsic image + decomposition. + uncertainty (`None`, `np.ndarray`, `torch.Tensor`): + Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is `(numimages * + numtargets) × 3 × height × width` for `torch.Tensor` or `(numimages * numtargets) × height × width × 3` for + `np.ndarray`. + latent (`None`, `torch.Tensor`): + Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline. + The shape is `(numimages * numensemble) × (numtargets * 4) × latentheight × latentwidth`. + """ + + prediction: np.ndarray | torch.Tensor + uncertainty: None | np.ndarray | torch.Tensor + latent: None | torch.Tensor + + +class MarigoldIntrinsicsPipeline(DiffusionPipeline): + """ + Pipeline for Intrinsic Image Decomposition (IID) using the Marigold method: + https://marigoldcomputervision.github.io. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + unet (`UNet2DConditionModel`): + Conditional U-Net to denoise the targets latent, conditioned on image latent. + vae (`AutoencoderKL`): + Variational Auto-Encoder (VAE) Model to encode and decode images and predictions to and from latent + representations. + scheduler (`DDIMScheduler` or `LCMScheduler`): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + text_encoder (`CLIPTextModel`): + Text-encoder, for empty text embedding. + tokenizer (`CLIPTokenizer`): + CLIP tokenizer. + prediction_type (`str`, *optional*): + Type of predictions made by the model. + target_properties (`dict[str, Any]`, *optional*): + Properties of the predicted modalities, such as `target_names`, a `list[str]` used to define the number, + order and names of the predicted modalities, and any other metadata that may be required to interpret the + predictions. + default_denoising_steps (`int`, *optional*): + The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable + quality with the given model. This value must be set in the model config. When the pipeline is called + without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure + reasonable results with various model flavors compatible with the pipeline, such as those relying on very + short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`). + default_processing_resolution (`int`, *optional*): + The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in + the model config. When the pipeline is called without explicitly setting `processing_resolution`, the + default value is used. This is required to ensure reasonable results with various model flavors trained + with varying optimal processing resolution values. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + supported_prediction_types = ("intrinsics",) + + def __init__( + self, + unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: DDIMScheduler | LCMScheduler, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + prediction_type: str | None = None, + target_properties: dict[str, Any] | None = None, + default_denoising_steps: int | None = None, + default_processing_resolution: int | None = None, + ): + super().__init__() + + if prediction_type not in self.supported_prediction_types: + logger.warning( + f"Potentially unsupported `prediction_type='{prediction_type}'`; values supported by the pipeline: " + f"{self.supported_prediction_types}." + ) + + self.register_modules( + unet=unet, + vae=vae, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + ) + self.register_to_config( + prediction_type=prediction_type, + target_properties=target_properties, + default_denoising_steps=default_denoising_steps, + default_processing_resolution=default_processing_resolution, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + + self.target_properties = target_properties + self.default_denoising_steps = default_denoising_steps + self.default_processing_resolution = default_processing_resolution + + self.empty_text_embedding = None + + self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor) + + @property + def n_targets(self): + return self.unet.config.out_channels // self.vae.config.latent_channels + + def check_inputs( + self, + image: PipelineImageInput, + num_inference_steps: int, + ensemble_size: int, + processing_resolution: int, + resample_method_input: str, + resample_method_output: str, + batch_size: int, + ensembling_kwargs: dict[str, Any] | None, + latents: torch.Tensor | None, + generator: torch.Generator | list[torch.Generator] | None, + output_type: str, + output_uncertainty: bool, + ) -> int: + actual_vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + if actual_vae_scale_factor != self.vae_scale_factor: + raise ValueError( + f"`vae_scale_factor` computed at initialization ({self.vae_scale_factor}) differs from the actual one ({actual_vae_scale_factor})." + ) + if num_inference_steps is None: + raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.") + if num_inference_steps < 1: + raise ValueError("`num_inference_steps` must be positive.") + if ensemble_size < 1: + raise ValueError("`ensemble_size` must be positive.") + if ensemble_size == 2: + logger.warning( + "`ensemble_size` == 2 results are similar to no ensembling (1); " + "consider increasing the value to at least 3." + ) + if ensemble_size == 1 and output_uncertainty: + raise ValueError( + "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` " + "greater than 1." + ) + if processing_resolution is None: + raise ValueError( + "`processing_resolution` is not specified and could not be resolved from the model config." + ) + if processing_resolution < 0: + raise ValueError( + "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for " + "downsampled processing." + ) + if processing_resolution % self.vae_scale_factor != 0: + raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.") + if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): + raise ValueError( + "`resample_method_input` takes string values compatible with PIL library: " + "nearest, nearest-exact, bilinear, bicubic, area." + ) + if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): + raise ValueError( + "`resample_method_output` takes string values compatible with PIL library: " + "nearest, nearest-exact, bilinear, bicubic, area." + ) + if batch_size < 1: + raise ValueError("`batch_size` must be positive.") + if output_type not in ["pt", "np"]: + raise ValueError("`output_type` must be one of `pt` or `np`.") + if latents is not None and generator is not None: + raise ValueError("`latents` and `generator` cannot be used together.") + if ensembling_kwargs is not None: + if not isinstance(ensembling_kwargs, dict): + raise ValueError("`ensembling_kwargs` must be a dictionary.") + if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("median", "mean"): + raise ValueError("`ensembling_kwargs['reduction']` can be either `'median'` or `'mean'`.") + + # image checks + num_images = 0 + W, H = None, None + if not isinstance(image, list): + image = [image] + for i, img in enumerate(image): + if isinstance(img, np.ndarray) or torch.is_tensor(img): + if img.ndim not in (2, 3, 4): + raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.") + H_i, W_i = img.shape[-2:] + N_i = 1 + if img.ndim == 4: + N_i = img.shape[0] + elif isinstance(img, Image.Image): + W_i, H_i = img.size + N_i = 1 + else: + raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.") + if W is None: + W, H = W_i, H_i + elif (W, H) != (W_i, H_i): + raise ValueError( + f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}" + ) + num_images += N_i + + # latents checks + if latents is not None: + if not torch.is_tensor(latents): + raise ValueError("`latents` must be a torch.Tensor.") + if latents.dim() != 4: + raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.") + + if processing_resolution > 0: + max_orig = max(H, W) + new_H = H * processing_resolution // max_orig + new_W = W * processing_resolution // max_orig + if new_H == 0 or new_W == 0: + raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]") + W, H = new_W, new_H + w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor + h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor + shape_expected = (num_images * ensemble_size, self.unet.config.out_channels, h, w) + + if latents.shape != shape_expected: + raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.") + + # generator checks + if generator is not None: + if isinstance(generator, list): + if len(generator) != num_images * ensemble_size: + raise ValueError( + "The number of generators must match the total number of ensemble members for all input images." + ) + if not all(g.device.type == generator[0].device.type for g in generator): + raise ValueError("`generator` device placement is not consistent in the list.") + elif not isinstance(generator, torch.Generator): + raise ValueError(f"Unsupported generator type: {type(generator)}.") + + return num_images + + @torch.compiler.disable + def progress_bar(self, iterable=None, total=None, desc=None, leave=True): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + progress_bar_config = dict(**self._progress_bar_config) + progress_bar_config["desc"] = progress_bar_config.get("desc", desc) + progress_bar_config["leave"] = progress_bar_config.get("leave", leave) + if iterable is not None: + return tqdm(iterable, **progress_bar_config) + elif total is not None: + return tqdm(total=total, **progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + num_inference_steps: int | None = None, + ensemble_size: int = 1, + processing_resolution: int | None = None, + match_input_resolution: bool = True, + resample_method_input: str = "bilinear", + resample_method_output: str = "bilinear", + batch_size: int = 1, + ensembling_kwargs: dict[str, Any] | None = None, + latents: torch.Tensor | list[torch.Tensor] | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + output_type: str = "np", + output_uncertainty: bool = False, + output_latent: bool = False, + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline. + + Args: + image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`), + `list[torch.Tensor]`: An input image or images used as an input for the intrinsic decomposition task. + For arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is + possible by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or + three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the + same width and height. + num_inference_steps (`int`, *optional*, defaults to `None`): + Number of denoising diffusion steps during inference. The default value `None` results in automatic + selection. + ensemble_size (`int`, defaults to `1`): + Number of ensemble predictions. Higher values result in measurable improvements and visual degradation. + processing_resolution (`int`, *optional*, defaults to `None`): + Effective processing resolution. When set to `0`, matches the larger input image dimension. This + produces crisper predictions, but may also lead to the overall loss of global context. The default + value `None` resolves to the optimal value from the model config. + match_input_resolution (`bool`, *optional*, defaults to `True`): + When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer + side of the output will equal to `processing_resolution`. + resample_method_input (`str`, *optional*, defaults to `"bilinear"`): + Resampling method used to resize input images to `processing_resolution`. The accepted values are: + `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. + resample_method_output (`str`, *optional*, defaults to `"bilinear"`): + Resampling method used to resize output predictions to match the input resolution. The accepted values + are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. + batch_size (`int`, *optional*, defaults to `1`): + Batch size; only matters when setting `ensemble_size` or passing a tensor of images. + ensembling_kwargs (`dict`, *optional*, defaults to `None`) + Extra dictionary with arguments for precise ensembling control. The following options are available: + - reduction (`str`, *optional*, defaults to `"median"`): Defines the ensembling function applied in + every pixel location, can be either `"median"` or `"mean"`. + latents (`torch.Tensor`, *optional*, defaults to `None`): + Latent noise tensors to replace the random initialization. These can be taken from the previous + function call's output. + generator (`torch.Generator`, or `list[torch.Generator]`, *optional*, defaults to `None`): + Random number generator object to ensure reproducibility. + output_type (`str`, *optional*, defaults to `"np"`): + Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted + values are: `"np"` (numpy array) or `"pt"` (torch tensor). + output_uncertainty (`bool`, *optional*, defaults to `False`): + When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that + the `ensemble_size` argument is set to a value above 2. + output_latent (`bool`, *optional*, defaults to `False`): + When enabled, the output's `latent` field contains the latent codes corresponding to the predictions + within the ensemble. These codes can be saved, modified, and used for subsequent calls with the + `latents` argument. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.marigold.MarigoldIntrinsicsOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.marigold.MarigoldIntrinsicsOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.marigold.MarigoldIntrinsicsOutput`] is returned, otherwise a + `tuple` is returned where the first element is the prediction, the second element is the uncertainty + (or `None`), and the third is the latent (or `None`). + """ + + # 0. Resolving variables. + device = self._execution_device + dtype = self.dtype + + # Model-specific optimal default values leading to fast and reasonable results. + if num_inference_steps is None: + num_inference_steps = self.default_denoising_steps + if processing_resolution is None: + processing_resolution = self.default_processing_resolution + + # 1. Check inputs. + num_images = self.check_inputs( + image, + num_inference_steps, + ensemble_size, + processing_resolution, + resample_method_input, + resample_method_output, + batch_size, + ensembling_kwargs, + latents, + generator, + output_type, + output_uncertainty, + ) + + # 2. Prepare empty text conditioning. + # Model invocation: self.tokenizer, self.text_encoder. + if self.empty_text_embedding is None: + prompt = "" + text_inputs = self.tokenizer( + prompt, + padding="do_not_pad", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024] + + # 3. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`, + # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where + # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are + # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None` + # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of + # operation and leads to the most reasonable results. Using the native image resolution or any other processing + # resolution can lead to loss of either fine details or global context in the output predictions. + image, padding, original_resolution = self.image_processor.preprocess( + image, processing_resolution, resample_method_input, device, dtype + ) # [N,3,PPH,PPW] + + # 4. Encode input image into latent space. At this step, each of the `N` input images is represented with `E` + # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently. + # Latents of each such predictions across all input images and all ensemble members are represented in the + # `pred_latent` variable. The variable `image_latent` contains each input image encoded into latent space and + # replicated `E` times. The variable `pred_latent` contains latents initialization, where the latent space is + # replicated `T` times relative to the single latent space of `image_latent`, where `T` is the number of the + # predicted targets. The latents can be either generated (see `generator` to ensure reproducibility), or passed + # explicitly via the `latents` argument. The latter can be set outside the pipeline code. This behavior can be + # achieved by setting the `output_latent` argument to `True`. The latent space dimensions are `(h, w)`. Encoding + # into latent space happens in batches of size `batch_size`. + # Model invocation: self.vae.encoder. + image_latent, pred_latent = self.prepare_latents( + image, latents, generator, ensemble_size, batch_size + ) # [N*E,4,h,w], [N*E,T*4,h,w] + + del image + + batch_empty_text_embedding = self.empty_text_embedding.to(device=device, dtype=dtype).repeat( + batch_size, 1, 1 + ) # [B,1024,2] + + # 5. Process the denoising loop. All `N * E` latents are processed sequentially in batches of size `batch_size`. + # The unet model takes concatenated latent spaces of the input image and the predicted modality as an input, and + # outputs noise for the predicted modality's latent space. The number of denoising diffusion steps is defined by + # `num_inference_steps`. It is either set directly, or resolves to the optimal value specific to the loaded + # model. + # Model invocation: self.unet. + pred_latents = [] + + for i in self.progress_bar( + range(0, num_images * ensemble_size, batch_size), leave=True, desc="Marigold predictions..." + ): + batch_image_latent = image_latent[i : i + batch_size] # [B,4,h,w] + batch_pred_latent = pred_latent[i : i + batch_size] # [B,T*4,h,w] + effective_batch_size = batch_image_latent.shape[0] + text = batch_empty_text_embedding[:effective_batch_size] # [B,2,1024] + + self.scheduler.set_timesteps(num_inference_steps, device=device) + for t in self.progress_bar(self.scheduler.timesteps, leave=False, desc="Diffusion steps..."): + batch_latent = torch.cat([batch_image_latent, batch_pred_latent], dim=1) # [B,(1+T)*4,h,w] + noise = self.unet(batch_latent, t, encoder_hidden_states=text, return_dict=False)[0] # [B,T*4,h,w] + batch_pred_latent = self.scheduler.step( + noise, t, batch_pred_latent, generator=generator + ).prev_sample # [B,T*4,h,w] + + if XLA_AVAILABLE: + xm.mark_step() + + pred_latents.append(batch_pred_latent) + + pred_latent = torch.cat(pred_latents, dim=0) # [N*E,T*4,h,w] + + del ( + pred_latents, + image_latent, + batch_empty_text_embedding, + batch_image_latent, + batch_pred_latent, + text, + batch_latent, + noise, + ) + + # 6. Decode predictions from latent into pixel space. The resulting `N * E` predictions have shape `(PPH, PPW)`, + # which requires slight postprocessing. Decoding into pixel space happens in batches of size `batch_size`. + # Model invocation: self.vae.decoder. + pred_latent_for_decoding = pred_latent.reshape( + num_images * ensemble_size * self.n_targets, self.vae.config.latent_channels, *pred_latent.shape[2:] + ) # [N*E*T,4,PPH,PPW] + prediction = torch.cat( + [ + self.decode_prediction(pred_latent_for_decoding[i : i + batch_size]) + for i in range(0, pred_latent_for_decoding.shape[0], batch_size) + ], + dim=0, + ) # [N*E*T,3,PPH,PPW] + + del pred_latent_for_decoding + if not output_latent: + pred_latent = None + + # 7. Remove padding. The output shape is (PH, PW). + prediction = self.image_processor.unpad_image(prediction, padding) # [N*E*T,3,PH,PW] + + # 8. Ensemble and compute uncertainty (when `output_uncertainty` is set). This code treats each of the `N*T` + # groups of `E` ensemble predictions independently. For each group it computes an ensembled prediction of shape + # `(PH, PW)` and an optional uncertainty map of the same dimensions. After computing this pair of outputs for + # each group independently, it stacks them respectively into batches of `N*T` almost final predictions and + # uncertainty maps. + uncertainty = None + if ensemble_size > 1: + prediction = prediction.reshape( + num_images, ensemble_size, self.n_targets, *prediction.shape[1:] + ) # [N,E,T,3,PH,PW] + prediction = [ + self.ensemble_intrinsics(prediction[i], output_uncertainty, **(ensembling_kwargs or {})) + for i in range(num_images) + ] # [ [[T,3,PH,PW], [T,3,PH,PW]], ... ] + prediction, uncertainty = zip(*prediction) # [[T,3,PH,PW], ... ], [[T,3,PH,PW], ... ] + prediction = torch.cat(prediction, dim=0) # [N*T,3,PH,PW] + if output_uncertainty: + uncertainty = torch.cat(uncertainty, dim=0) # [N*T,3,PH,PW] + else: + uncertainty = None + + # 9. If `match_input_resolution` is set, the output prediction and the uncertainty are upsampled to match the + # input resolution `(H, W)`. This step may introduce upsampling artifacts, and therefore can be disabled. + # Depending on the downstream use-case, upsampling can be also chosen based on the tolerated artifacts by + # setting the `resample_method_output` parameter (e.g., to `"nearest"`). + if match_input_resolution: + prediction = self.image_processor.resize_antialias( + prediction, original_resolution, resample_method_output, is_aa=False + ) # [N*T,3,H,W] + if uncertainty is not None and output_uncertainty: + uncertainty = self.image_processor.resize_antialias( + uncertainty, original_resolution, resample_method_output, is_aa=False + ) # [N*T,1,H,W] + + # 10. Prepare the final outputs. + if output_type == "np": + prediction = self.image_processor.pt_to_numpy(prediction) # [N*T,H,W,3] + if uncertainty is not None and output_uncertainty: + uncertainty = self.image_processor.pt_to_numpy(uncertainty) # [N*T,H,W,3] + + # 11. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (prediction, uncertainty, pred_latent) + + return MarigoldIntrinsicsOutput( + prediction=prediction, + uncertainty=uncertainty, + latent=pred_latent, + ) + + def prepare_latents( + self, + image: torch.Tensor, + latents: torch.Tensor | None, + generator: torch.Generator | None, + ensemble_size: int, + batch_size: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + def retrieve_latents(encoder_output): + if hasattr(encoder_output, "latent_dist"): + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + image_latent = torch.cat( + [ + retrieve_latents(self.vae.encode(image[i : i + batch_size])) + for i in range(0, image.shape[0], batch_size) + ], + dim=0, + ) # [N,4,h,w] + image_latent = image_latent * self.vae.config.scaling_factor + image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w] + N_E, C, H, W = image_latent.shape + + pred_latent = latents + if pred_latent is None: + pred_latent = randn_tensor( + (N_E, self.n_targets * C, H, W), + generator=generator, + device=image_latent.device, + dtype=image_latent.dtype, + ) # [N*E,T*4,h,w] + + return image_latent, pred_latent + + def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor: + if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels: + raise ValueError( + f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}." + ) + + prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W] + + prediction = torch.clip(prediction, -1.0, 1.0) # [B,3,H,W] + prediction = (prediction + 1.0) / 2.0 + + return prediction # [B,3,H,W] + + @staticmethod + def ensemble_intrinsics( + targets: torch.Tensor, + output_uncertainty: bool = False, + reduction: str = "median", + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Ensembles the intrinsic decomposition represented by the `targets` tensor with expected shape `(B, T, 3, H, + W)`, where B is the number of ensemble members for a given prediction of size `(H x W)`, and T is the number of + predicted targets. + + Args: + targets (`torch.Tensor`): + Input ensemble of intrinsic image decomposition maps. + output_uncertainty (`bool`, *optional*, defaults to `False`): + Whether to output uncertainty map. + reduction (`str`, *optional*, defaults to `"mean"`): + Reduction method used to ensemble aligned predictions. The accepted values are: `"median"` and + `"mean"`. + + Returns: + A tensor of aligned and ensembled intrinsic decomposition maps with shape `(T, 3, H, W)` and optionally a + tensor of uncertainties of shape `(T, 3, H, W)`. + """ + if targets.dim() != 5 or targets.shape[2] != 3: + raise ValueError(f"Expecting 4D tensor of shape [B,T,3,H,W]; got {targets.shape}.") + if reduction not in ("median", "mean"): + raise ValueError(f"Unrecognized reduction method: {reduction}.") + + B, T, _, H, W = targets.shape + uncertainty = None + if reduction == "mean": + prediction = torch.mean(targets, dim=0) # [T,3,H,W] + if output_uncertainty: + uncertainty = torch.std(targets, dim=0) # [T,3,H,W] + elif reduction == "median": + prediction = torch.median(targets, dim=0, keepdim=True).values # [1,T,3,H,W] + if output_uncertainty: + uncertainty = torch.abs(targets - prediction) # [B,T,3,H,W] + uncertainty = torch.median(uncertainty, dim=0).values # [T,3,H,W] + prediction = prediction.squeeze(0) # [T,3,H,W] + else: + raise ValueError(f"Unrecognized reduction method: {reduction}.") + return prediction, uncertainty diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/pipeline_marigold_normals.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/pipeline_marigold_normals.py new file mode 100644 index 0000000000000000000000000000000000000000..3f94ce441232fcdd950e4ce2d6515d939c3b1cbf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/marigold/pipeline_marigold_normals.py @@ -0,0 +1,704 @@ +# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. +# Copyright 2024-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# More information and citation instructions are available on the +# Marigold project website: https://marigoldcomputervision.github.io +# -------------------------------------------------------------------------- +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch +from PIL import Image +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +from ...image_processor import PipelineImageInput +from ...models import ( + AutoencoderKL, + UNet2DConditionModel, +) +from ...schedulers import ( + DDIMScheduler, + LCMScheduler, +) +from ...utils import ( + BaseOutput, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .marigold_image_processing import MarigoldImageProcessor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ +Examples: +```py +>>> import diffusers +>>> import torch + +>>> pipe = diffusers.MarigoldNormalsPipeline.from_pretrained( +... "prs-eth/marigold-normals-v1-1", variant="fp16", torch_dtype=torch.float16 +... ).to("cuda") + +>>> image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg") +>>> normals = pipe(image) + +>>> vis = pipe.image_processor.visualize_normals(normals.prediction) +>>> vis[0].save("einstein_normals.png") +``` +""" + + +@dataclass +class MarigoldNormalsOutput(BaseOutput): + """ + Output class for Marigold monocular normals prediction pipeline. + + Args: + prediction (`np.ndarray`, `torch.Tensor`): + Predicted normals with values in the range [-1, 1]. The shape is `numimages × 3 × height × width` for + `torch.Tensor` or `numimages × height × width × 3` for `np.ndarray`. + uncertainty (`None`, `np.ndarray`, `torch.Tensor`): + Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is `numimages × 1 × + height × width` for `torch.Tensor` or `numimages × height × width × 1` for `np.ndarray`. + latent (`None`, `torch.Tensor`): + Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline. + The shape is `numimages * numensemble × 4 × latentheight × latentwidth`. + """ + + prediction: np.ndarray | torch.Tensor + uncertainty: None | np.ndarray | torch.Tensor + latent: None | torch.Tensor + + +class MarigoldNormalsPipeline(DiffusionPipeline): + """ + Pipeline for monocular normals estimation using the Marigold method: https://marigoldmonodepth.github.io. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + unet (`UNet2DConditionModel`): + Conditional U-Net to denoise the normals latent, conditioned on image latent. + vae (`AutoencoderKL`): + Variational Auto-Encoder (VAE) Model to encode and decode images and predictions to and from latent + representations. + scheduler (`DDIMScheduler` or `LCMScheduler`): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + text_encoder (`CLIPTextModel`): + Text-encoder, for empty text embedding. + tokenizer (`CLIPTokenizer`): + CLIP tokenizer. + prediction_type (`str`, *optional*): + Type of predictions made by the model. + use_full_z_range (`bool`, *optional*): + Whether the normals predicted by this model utilize the full range of the Z dimension, or only its positive + half. + default_denoising_steps (`int`, *optional*): + The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable + quality with the given model. This value must be set in the model config. When the pipeline is called + without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure + reasonable results with various model flavors compatible with the pipeline, such as those relying on very + short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`). + default_processing_resolution (`int`, *optional*): + The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in + the model config. When the pipeline is called without explicitly setting `processing_resolution`, the + default value is used. This is required to ensure reasonable results with various model flavors trained + with varying optimal processing resolution values. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + supported_prediction_types = ("normals",) + + def __init__( + self, + unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: DDIMScheduler | LCMScheduler, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + prediction_type: str | None = None, + use_full_z_range: bool | None = True, + default_denoising_steps: int | None = None, + default_processing_resolution: int | None = None, + ): + super().__init__() + + if prediction_type not in self.supported_prediction_types: + logger.warning( + f"Potentially unsupported `prediction_type='{prediction_type}'`; values supported by the pipeline: " + f"{self.supported_prediction_types}." + ) + + self.register_modules( + unet=unet, + vae=vae, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + ) + self.register_to_config( + prediction_type=prediction_type, + use_full_z_range=use_full_z_range, + default_denoising_steps=default_denoising_steps, + default_processing_resolution=default_processing_resolution, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + + self.use_full_z_range = use_full_z_range + self.default_denoising_steps = default_denoising_steps + self.default_processing_resolution = default_processing_resolution + + self.empty_text_embedding = None + + self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def check_inputs( + self, + image: PipelineImageInput, + num_inference_steps: int, + ensemble_size: int, + processing_resolution: int, + resample_method_input: str, + resample_method_output: str, + batch_size: int, + ensembling_kwargs: dict[str, Any] | None, + latents: torch.Tensor | None, + generator: torch.Generator | list[torch.Generator] | None, + output_type: str, + output_uncertainty: bool, + ) -> int: + actual_vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + if actual_vae_scale_factor != self.vae_scale_factor: + raise ValueError( + f"`vae_scale_factor` computed at initialization ({self.vae_scale_factor}) differs from the actual one ({actual_vae_scale_factor})." + ) + if num_inference_steps is None: + raise ValueError("`num_inference_steps` is not specified and could not be resolved from the model config.") + if num_inference_steps < 1: + raise ValueError("`num_inference_steps` must be positive.") + if ensemble_size < 1: + raise ValueError("`ensemble_size` must be positive.") + if ensemble_size == 2: + logger.warning( + "`ensemble_size` == 2 results are similar to no ensembling (1); " + "consider increasing the value to at least 3." + ) + if ensemble_size == 1 and output_uncertainty: + raise ValueError( + "Computing uncertainty by setting `output_uncertainty=True` also requires setting `ensemble_size` " + "greater than 1." + ) + if processing_resolution is None: + raise ValueError( + "`processing_resolution` is not specified and could not be resolved from the model config." + ) + if processing_resolution < 0: + raise ValueError( + "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for " + "downsampled processing." + ) + if processing_resolution % self.vae_scale_factor != 0: + raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.") + if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): + raise ValueError( + "`resample_method_input` takes string values compatible with PIL library: " + "nearest, nearest-exact, bilinear, bicubic, area." + ) + if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): + raise ValueError( + "`resample_method_output` takes string values compatible with PIL library: " + "nearest, nearest-exact, bilinear, bicubic, area." + ) + if batch_size < 1: + raise ValueError("`batch_size` must be positive.") + if output_type not in ["pt", "np"]: + raise ValueError("`output_type` must be one of `pt` or `np`.") + if latents is not None and generator is not None: + raise ValueError("`latents` and `generator` cannot be used together.") + if ensembling_kwargs is not None: + if not isinstance(ensembling_kwargs, dict): + raise ValueError("`ensembling_kwargs` must be a dictionary.") + if "reduction" in ensembling_kwargs and ensembling_kwargs["reduction"] not in ("closest", "mean"): + raise ValueError("`ensembling_kwargs['reduction']` can be either `'closest'` or `'mean'`.") + + # image checks + num_images = 0 + W, H = None, None + if not isinstance(image, list): + image = [image] + for i, img in enumerate(image): + if isinstance(img, np.ndarray) or torch.is_tensor(img): + if img.ndim not in (2, 3, 4): + raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.") + H_i, W_i = img.shape[-2:] + N_i = 1 + if img.ndim == 4: + N_i = img.shape[0] + elif isinstance(img, Image.Image): + W_i, H_i = img.size + N_i = 1 + else: + raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.") + if W is None: + W, H = W_i, H_i + elif (W, H) != (W_i, H_i): + raise ValueError( + f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}" + ) + num_images += N_i + + # latents checks + if latents is not None: + if not torch.is_tensor(latents): + raise ValueError("`latents` must be a torch.Tensor.") + if latents.dim() != 4: + raise ValueError(f"`latents` has unsupported dimensions or shape: {latents.shape}.") + + if processing_resolution > 0: + max_orig = max(H, W) + new_H = H * processing_resolution // max_orig + new_W = W * processing_resolution // max_orig + if new_H == 0 or new_W == 0: + raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]") + W, H = new_W, new_H + w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor + h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor + shape_expected = (num_images * ensemble_size, self.vae.config.latent_channels, h, w) + + if latents.shape != shape_expected: + raise ValueError(f"`latents` has unexpected shape={latents.shape} expected={shape_expected}.") + + # generator checks + if generator is not None: + if isinstance(generator, list): + if len(generator) != num_images * ensemble_size: + raise ValueError( + "The number of generators must match the total number of ensemble members for all input images." + ) + if not all(g.device.type == generator[0].device.type for g in generator): + raise ValueError("`generator` device placement is not consistent in the list.") + elif not isinstance(generator, torch.Generator): + raise ValueError(f"Unsupported generator type: {type(generator)}.") + + return num_images + + @torch.compiler.disable + def progress_bar(self, iterable=None, total=None, desc=None, leave=True): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + progress_bar_config = dict(**self._progress_bar_config) + progress_bar_config["desc"] = progress_bar_config.get("desc", desc) + progress_bar_config["leave"] = progress_bar_config.get("leave", leave) + if iterable is not None: + return tqdm(iterable, **progress_bar_config) + elif total is not None: + return tqdm(total=total, **progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + num_inference_steps: int | None = None, + ensemble_size: int = 1, + processing_resolution: int | None = None, + match_input_resolution: bool = True, + resample_method_input: str = "bilinear", + resample_method_output: str = "bilinear", + batch_size: int = 1, + ensembling_kwargs: dict[str, Any] | None = None, + latents: torch.Tensor | list[torch.Tensor] | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + output_type: str = "np", + output_uncertainty: bool = False, + output_latent: bool = False, + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline. + + Args: + image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`), + `list[torch.Tensor]`: An input image or images used as an input for the normals estimation task. For + arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible + by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or + three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the + same width and height. + num_inference_steps (`int`, *optional*, defaults to `None`): + Number of denoising diffusion steps during inference. The default value `None` results in automatic + selection. + ensemble_size (`int`, defaults to `1`): + Number of ensemble predictions. Higher values result in measurable improvements and visual degradation. + processing_resolution (`int`, *optional*, defaults to `None`): + Effective processing resolution. When set to `0`, matches the larger input image dimension. This + produces crisper predictions, but may also lead to the overall loss of global context. The default + value `None` resolves to the optimal value from the model config. + match_input_resolution (`bool`, *optional*, defaults to `True`): + When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer + side of the output will equal to `processing_resolution`. + resample_method_input (`str`, *optional*, defaults to `"bilinear"`): + Resampling method used to resize input images to `processing_resolution`. The accepted values are: + `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. + resample_method_output (`str`, *optional*, defaults to `"bilinear"`): + Resampling method used to resize output predictions to match the input resolution. The accepted values + are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. + batch_size (`int`, *optional*, defaults to `1`): + Batch size; only matters when setting `ensemble_size` or passing a tensor of images. + ensembling_kwargs (`dict`, *optional*, defaults to `None`) + Extra dictionary with arguments for precise ensembling control. The following options are available: + - reduction (`str`, *optional*, defaults to `"closest"`): Defines the ensembling function applied in + every pixel location, can be either `"closest"` or `"mean"`. + latents (`torch.Tensor`, *optional*, defaults to `None`): + Latent noise tensors to replace the random initialization. These can be taken from the previous + function call's output. + generator (`torch.Generator`, or `list[torch.Generator]`, *optional*, defaults to `None`): + Random number generator object to ensure reproducibility. + output_type (`str`, *optional*, defaults to `"np"`): + Preferred format of the output's `prediction` and the optional `uncertainty` fields. The accepted + values are: `"np"` (numpy array) or `"pt"` (torch tensor). + output_uncertainty (`bool`, *optional*, defaults to `False`): + When enabled, the output's `uncertainty` field contains the predictive uncertainty map, provided that + the `ensemble_size` argument is set to a value above 2. + output_latent (`bool`, *optional*, defaults to `False`): + When enabled, the output's `latent` field contains the latent codes corresponding to the predictions + within the ensemble. These codes can be saved, modified, and used for subsequent calls with the + `latents` argument. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.marigold.MarigoldNormalsOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.marigold.MarigoldNormalsOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.marigold.MarigoldNormalsOutput`] is returned, otherwise a + `tuple` is returned where the first element is the prediction, the second element is the uncertainty + (or `None`), and the third is the latent (or `None`). + """ + + # 0. Resolving variables. + device = self._execution_device + dtype = self.dtype + + # Model-specific optimal default values leading to fast and reasonable results. + if num_inference_steps is None: + num_inference_steps = self.default_denoising_steps + if processing_resolution is None: + processing_resolution = self.default_processing_resolution + + # 1. Check inputs. + num_images = self.check_inputs( + image, + num_inference_steps, + ensemble_size, + processing_resolution, + resample_method_input, + resample_method_output, + batch_size, + ensembling_kwargs, + latents, + generator, + output_type, + output_uncertainty, + ) + + # 2. Prepare empty text conditioning. + # Model invocation: self.tokenizer, self.text_encoder. + if self.empty_text_embedding is None: + prompt = "" + text_inputs = self.tokenizer( + prompt, + padding="do_not_pad", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024] + + # 3. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`, + # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where + # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are + # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None` + # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of + # operation and leads to the most reasonable results. Using the native image resolution or any other processing + # resolution can lead to loss of either fine details or global context in the output predictions. + image, padding, original_resolution = self.image_processor.preprocess( + image, processing_resolution, resample_method_input, device, dtype + ) # [N,3,PPH,PPW] + + # 4. Encode input image into latent space. At this step, each of the `N` input images is represented with `E` + # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently. + # Latents of each such predictions across all input images and all ensemble members are represented in the + # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded + # into latent space and replicated `E` times. The latents can be either generated (see `generator` to ensure + # reproducibility), or passed explicitly via the `latents` argument. The latter can be set outside the pipeline + # code. This behavior can be achieved by setting the `output_latent` argument to `True`. The latent space + # dimensions are `(h, w)`. Encoding into latent space happens in batches of size `batch_size`. + # Model invocation: self.vae.encoder. + image_latent, pred_latent = self.prepare_latents( + image, latents, generator, ensemble_size, batch_size + ) # [N*E,4,h,w], [N*E,4,h,w] + + del image + + batch_empty_text_embedding = self.empty_text_embedding.to(device=device, dtype=dtype).repeat( + batch_size, 1, 1 + ) # [B,1024,2] + + # 5. Process the denoising loop. All `N * E` latents are processed sequentially in batches of size `batch_size`. + # The unet model takes concatenated latent spaces of the input image and the predicted modality as an input, and + # outputs noise for the predicted modality's latent space. The number of denoising diffusion steps is defined by + # `num_inference_steps`. It is either set directly, or resolves to the optimal value specific to the loaded + # model. + # Model invocation: self.unet. + pred_latents = [] + + for i in self.progress_bar( + range(0, num_images * ensemble_size, batch_size), leave=True, desc="Marigold predictions..." + ): + batch_image_latent = image_latent[i : i + batch_size] # [B,4,h,w] + batch_pred_latent = pred_latent[i : i + batch_size] # [B,4,h,w] + effective_batch_size = batch_image_latent.shape[0] + text = batch_empty_text_embedding[:effective_batch_size] # [B,2,1024] + + self.scheduler.set_timesteps(num_inference_steps, device=device) + for t in self.progress_bar(self.scheduler.timesteps, leave=False, desc="Diffusion steps..."): + batch_latent = torch.cat([batch_image_latent, batch_pred_latent], dim=1) # [B,8,h,w] + noise = self.unet(batch_latent, t, encoder_hidden_states=text, return_dict=False)[0] # [B,4,h,w] + batch_pred_latent = self.scheduler.step( + noise, t, batch_pred_latent, generator=generator + ).prev_sample # [B,4,h,w] + + if XLA_AVAILABLE: + xm.mark_step() + + pred_latents.append(batch_pred_latent) + + pred_latent = torch.cat(pred_latents, dim=0) # [N*E,4,h,w] + + del ( + pred_latents, + image_latent, + batch_empty_text_embedding, + batch_image_latent, + batch_pred_latent, + text, + batch_latent, + noise, + ) + + # 6. Decode predictions from latent into pixel space. The resulting `N * E` predictions have shape `(PPH, PPW)`, + # which requires slight postprocessing. Decoding into pixel space happens in batches of size `batch_size`. + # Model invocation: self.vae.decoder. + prediction = torch.cat( + [ + self.decode_prediction(pred_latent[i : i + batch_size]) + for i in range(0, pred_latent.shape[0], batch_size) + ], + dim=0, + ) # [N*E,3,PPH,PPW] + + if not output_latent: + pred_latent = None + + # 7. Remove padding. The output shape is (PH, PW). + prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,3,PH,PW] + + # 8. Ensemble and compute uncertainty (when `output_uncertainty` is set). This code treats each of the `N` + # groups of `E` ensemble predictions independently. For each group it computes an ensembled prediction of shape + # `(PH, PW)` and an optional uncertainty map of the same dimensions. After computing this pair of outputs for + # each group independently, it stacks them respectively into batches of `N` almost final predictions and + # uncertainty maps. + uncertainty = None + if ensemble_size > 1: + prediction = prediction.reshape(num_images, ensemble_size, *prediction.shape[1:]) # [N,E,3,PH,PW] + prediction = [ + self.ensemble_normals(prediction[i], output_uncertainty, **(ensembling_kwargs or {})) + for i in range(num_images) + ] # [ [[1,3,PH,PW], [1,1,PH,PW]], ... ] + prediction, uncertainty = zip(*prediction) # [[1,3,PH,PW], ... ], [[1,1,PH,PW], ... ] + prediction = torch.cat(prediction, dim=0) # [N,3,PH,PW] + if output_uncertainty: + uncertainty = torch.cat(uncertainty, dim=0) # [N,1,PH,PW] + else: + uncertainty = None + + # 9. If `match_input_resolution` is set, the output prediction and the uncertainty are upsampled to match the + # input resolution `(H, W)`. This step may introduce upsampling artifacts, and therefore can be disabled. + # After upsampling, the native resolution normal maps are renormalized to unit length to reduce the artifacts. + # Depending on the downstream use-case, upsampling can be also chosen based on the tolerated artifacts by + # setting the `resample_method_output` parameter (e.g., to `"nearest"`). + if match_input_resolution: + prediction = self.image_processor.resize_antialias( + prediction, original_resolution, resample_method_output, is_aa=False + ) # [N,3,H,W] + prediction = self.normalize_normals(prediction) # [N,3,H,W] + if uncertainty is not None and output_uncertainty: + uncertainty = self.image_processor.resize_antialias( + uncertainty, original_resolution, resample_method_output, is_aa=False + ) # [N,1,H,W] + + # 10. Prepare the final outputs. + if output_type == "np": + prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,3] + if uncertainty is not None and output_uncertainty: + uncertainty = self.image_processor.pt_to_numpy(uncertainty) # [N,H,W,1] + + # 11. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (prediction, uncertainty, pred_latent) + + return MarigoldNormalsOutput( + prediction=prediction, + uncertainty=uncertainty, + latent=pred_latent, + ) + + # Copied from diffusers.pipelines.marigold.pipeline_marigold_depth.MarigoldDepthPipeline.prepare_latents + def prepare_latents( + self, + image: torch.Tensor, + latents: torch.Tensor | None, + generator: torch.Generator | None, + ensemble_size: int, + batch_size: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + def retrieve_latents(encoder_output): + if hasattr(encoder_output, "latent_dist"): + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + image_latent = torch.cat( + [ + retrieve_latents(self.vae.encode(image[i : i + batch_size])) + for i in range(0, image.shape[0], batch_size) + ], + dim=0, + ) # [N,4,h,w] + image_latent = image_latent * self.vae.config.scaling_factor + image_latent = image_latent.repeat_interleave(ensemble_size, dim=0) # [N*E,4,h,w] + + pred_latent = latents + if pred_latent is None: + pred_latent = randn_tensor( + image_latent.shape, + generator=generator, + device=image_latent.device, + dtype=image_latent.dtype, + ) # [N*E,4,h,w] + + return image_latent, pred_latent + + def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor: + if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels: + raise ValueError( + f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}." + ) + + prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W] + + prediction = torch.clip(prediction, -1.0, 1.0) + + if not self.use_full_z_range: + prediction[:, 2, :, :] *= 0.5 + prediction[:, 2, :, :] += 0.5 + + prediction = self.normalize_normals(prediction) # [B,3,H,W] + + return prediction # [B,3,H,W] + + @staticmethod + def normalize_normals(normals: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + if normals.dim() != 4 or normals.shape[1] != 3: + raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.") + + norm = torch.norm(normals, dim=1, keepdim=True) + normals /= norm.clamp(min=eps) + + return normals + + @staticmethod + def ensemble_normals( + normals: torch.Tensor, output_uncertainty: bool, reduction: str = "closest" + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Ensembles the normals maps represented by the `normals` tensor with expected shape `(B, 3, H, W)`, where B is + the number of ensemble members for a given prediction of size `(H x W)`. + + Args: + normals (`torch.Tensor`): + Input ensemble normals maps. + output_uncertainty (`bool`, *optional*, defaults to `False`): + Whether to output uncertainty map. + reduction (`str`, *optional*, defaults to `"closest"`): + Reduction method used to ensemble aligned predictions. The accepted values are: `"closest"` and + `"mean"`. + + Returns: + A tensor of aligned and ensembled normals maps with shape `(1, 3, H, W)` and optionally a tensor of + uncertainties of shape `(1, 1, H, W)`. + """ + if normals.dim() != 4 or normals.shape[1] != 3: + raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.") + if reduction not in ("closest", "mean"): + raise ValueError(f"Unrecognized reduction method: {reduction}.") + + mean_normals = normals.mean(dim=0, keepdim=True) # [1,3,H,W] + mean_normals = MarigoldNormalsPipeline.normalize_normals(mean_normals) # [1,3,H,W] + + sim_cos = (mean_normals * normals).sum(dim=1, keepdim=True) # [E,1,H,W] + sim_cos = sim_cos.clamp(-1, 1) # required to avoid NaN in uncertainty with fp16 + + uncertainty = None + if output_uncertainty: + uncertainty = sim_cos.arccos() # [E,1,H,W] + uncertainty = uncertainty.mean(dim=0, keepdim=True) / np.pi # [1,1,H,W] + + if reduction == "mean": + return mean_normals, uncertainty # [1,3,H,W], [1,1,H,W] + + closest_indices = sim_cos.argmax(dim=0, keepdim=True) # [1,1,H,W] + closest_indices = closest_indices.repeat(1, 3, 1, 1) # [1,3,H,W] + closest_normals = torch.gather(normals, 0, closest_indices) # [1,3,H,W] + + return closest_normals, uncertainty # [1,3,H,W], [1,1,H,W] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/mochi/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/mochi/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a8fd4da9fd36e28479df6258f7de108b63d3c2bb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/mochi/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_mochi"] = ["MochiPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_mochi import MochiPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/mochi/pipeline_mochi.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/mochi/pipeline_mochi.py new file mode 100644 index 0000000000000000000000000000000000000000..e8acc0a75e4d471572db5f2901d7ce04a5bd88ef --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/mochi/pipeline_mochi.py @@ -0,0 +1,769 @@ +# Copyright 2025 Genmo and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import Mochi1LoraLoaderMixin +from ...models import AutoencoderKLMochi, MochiTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import MochiPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import MochiPipeline + >>> from diffusers.utils import export_to_video + + >>> pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16) + >>> pipe.enable_model_cpu_offload() + >>> pipe.enable_vae_tiling() + >>> prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k." + >>> frames = pipe(prompt, num_inference_steps=28, guidance_scale=3.5).frames[0] + >>> export_to_video(frames, "mochi.mp4") + ``` +""" + + +# from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 +def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None): + if linear_steps is None: + linear_steps = num_steps // 2 + linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)] + threshold_noise_step_diff = linear_steps - threshold_noise * num_steps + quadratic_steps = num_steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2) + const = quadratic_coef * (linear_steps**2) + quadratic_sigma_schedule = [ + quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps) + ] + sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + sigma_schedule = [1.0 - x for x in sigma_schedule] + return sigma_schedule + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin): + r""" + The mochi pipeline for text-to-video generation. + + Reference: https://github.com/genmoai/models + + Args: + transformer ([`MochiTransformer3DModel`]): + Conditional Transformer architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLMochi`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLMochi, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: MochiTransformer3DModel, + force_zeros_for_empty_prompt: bool = False, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + # TODO: determine these scaling factors from model parameters + self.vae_spatial_scale_factor = 8 + self.vae_temporal_scale_factor = 6 + self.patch_size = 2 + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 256 + ) + self.default_height = 480 + self.default_width = 848 + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 256, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.bool().to(device) + + # The original Mochi implementation zeros out empty negative prompts + # but this can lead to overflow when placing the entire pipeline under the autocast context + # adding this here so that we can enable zeroing prompts if necessary + if self.config.force_zeros_for_empty_prompt and (prompt == "" or prompt[-1] == ""): + text_input_ids = torch.zeros_like(text_input_ids, device=device) + prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + max_sequence_length: int = 256, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + num_frames, + dtype, + device, + generator, + latents=None, + ): + height = height // self.vae_spatial_scale_factor + width = width // self.vae_spatial_scale_factor + num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + + if latents is not None: + return latents.to(device=device, dtype=dtype) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32) + latents = latents.to(dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_frames: int = 19, + num_inference_steps: int = 64, + timesteps: list[int] = None, + guidance_scale: float = 4.5, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 256, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to `self.default_height`): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to `self.default_width`): + The width in pixels of the generated image. This is set to 848 by default for the best results. + num_frames (`int`, defaults to `19`): + The number of video frames to generate + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, defaults to `4.5`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.mochi.MochiPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `256`): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.mochi.MochiPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.mochi.MochiPipelineOutput`] is returned, otherwise a `tuple` + is returned where the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.default_height + width = width or self.default_width + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 5. Prepare timestep + # from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 + threshold_noise = 0.025 + sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise) + sigmas = np.array(sigmas) + + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + timestep_device, + timesteps, + sigmas, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # Note: Mochi uses reversed timesteps. To ensure compatibility with methods like FasterCache, we need + # to make sure we're using the correct non-reversed timestep values. + self._current_timestep = 1000 - t + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + # Mochi CFG + Sampling runs in FP32 + noise_pred = noise_pred.to(torch.float32) + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents.to(torch.float32), return_dict=False)[0] + latents = latents.to(latents_dtype) + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + video = latents + else: + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return MochiPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/mochi/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/mochi/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..5068cf930aaa4c77822f663c7fe9698b570a0710 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/mochi/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class MochiPipelineOutput(BaseOutput): + r""" + Output class for Mochi pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/musicldm/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/musicldm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed71eeb1d99b28f20f7cd94776c0303208620653 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/musicldm/__init__.py @@ -0,0 +1,49 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_musicldm"] = ["MusicLDMPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_musicldm import MusicLDMPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/musicldm/pipeline_musicldm.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/musicldm/pipeline_musicldm.py new file mode 100644 index 0000000000000000000000000000000000000000..e7747a4f8c3dc47d1d489b1a7c27ba9220bf512b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/musicldm/pipeline_musicldm.py @@ -0,0 +1,653 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import ( + ClapFeatureExtractor, + ClapModel, + ClapTextModelWithProjection, + RobertaTokenizer, + RobertaTokenizerFast, + SpeechT5HifiGan, +) + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + is_librosa_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import empty_device_cache, get_device, randn_tensor +from ..pipeline_utils import AudioPipelineOutput, DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin + + +if is_librosa_available(): + import librosa + + +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import MusicLDMPipeline + >>> import torch + >>> import scipy + + >>> repo_id = "ucsd-reach/musicldm" + >>> pipe = MusicLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "Techno music with a strong, upbeat tempo and high melodic riffs" + >>> audio = pipe(prompt, num_inference_steps=10, audio_length_in_s=5.0).audios[0] + + >>> # save the audio sample as a .wav file + >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio) + ``` +""" + + +class MusicLDMPipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin): + _last_supported_version = "0.33.1" + r""" + Pipeline for text-to-audio generation using MusicLDM. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.ClapModel`]): + Frozen text-audio embedding model (`ClapTextModel`), specifically the + [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. + tokenizer ([`PreTrainedTokenizer`]): + A [`~transformers.RobertaTokenizer`] to tokenize text. + feature_extractor ([`~transformers.ClapFeatureExtractor`]): + Feature extractor to compute mel-spectrograms from audio waveforms. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded audio latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + vocoder ([`~transformers.SpeechT5HifiGan`]): + Vocoder of class `SpeechT5HifiGan`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: ClapTextModelWithProjection | ClapModel, + tokenizer: RobertaTokenizer | RobertaTokenizerFast, + feature_extractor: ClapFeatureExtractor | None, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + vocoder: SpeechT5HifiGan, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + feature_extractor=feature_extractor, + unet=unet, + scheduler=scheduler, + vocoder=vocoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + + def _encode_prompt( + self, + prompt, + device, + num_waveforms_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device (`torch.device`): + torch device + num_waveforms_per_prompt (`int`): + number of waveforms that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the audio generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLAP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder.get_text_features( + text_input_ids.to(device), + attention_mask=attention_mask.to(device), + ) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.text_model.dtype, device=device) + + ( + bs_embed, + seq_len, + ) = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt) + prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_input_ids = uncond_input.input_ids.to(device) + attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder.get_text_features( + uncond_input_ids, + attention_mask=attention_mask, + ) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.text_model.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.mel_spectrogram_to_waveform + def mel_spectrogram_to_waveform(self, mel_spectrogram): + if mel_spectrogram.dim() == 4: + mel_spectrogram = mel_spectrogram.squeeze(1) + + waveform = self.vocoder(mel_spectrogram) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + waveform = waveform.cpu().float() + return waveform + + # Copied from diffusers.pipelines.audioldm2.pipeline_audioldm2.AudioLDM2Pipeline.score_waveforms + def score_waveforms(self, text, audio, num_waveforms_per_prompt, device, dtype): + if not is_librosa_available(): + logger.info( + "Automatic scoring of the generated audio waveforms against the input prompt text requires the " + "`librosa` package to resample the generated waveforms. Returning the audios in the order they were " + "generated. To enable automatic scoring, install `librosa` with: `pip install librosa`." + ) + return audio + inputs = self.tokenizer(text, return_tensors="pt", padding=True) + resampled_audio = librosa.resample( + audio.numpy(), orig_sr=self.vocoder.config.sampling_rate, target_sr=self.feature_extractor.sampling_rate + ) + inputs["input_features"] = self.feature_extractor( + list(resampled_audio), return_tensors="pt", sampling_rate=self.feature_extractor.sampling_rate + ).input_features.type(dtype) + inputs = inputs.to(device) + + # compute the audio-text similarity score using the CLAP model + logits_per_text = self.text_encoder(**inputs).logits_per_text + # sort by the highest matching generations per prompt + indices = torch.argsort(logits_per_text, dim=1, descending=True)[:, :num_waveforms_per_prompt] + audio = torch.index_select(audio, 0, indices.reshape(-1).cpu()) + return audio + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.check_inputs + def check_inputs( + self, + prompt, + audio_length_in_s, + vocoder_upsample_factor, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor + if audio_length_in_s < min_audio_length_in_s: + raise ValueError( + f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but " + f"is {audio_length_in_s}." + ) + + if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0: + raise ValueError( + f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the " + f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of " + f"{self.vae_scale_factor}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(self.vocoder.config.model_in_dim) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the accelerator when its + `forward` method is called, and the model remains in accelerator until the next model runs. Memory savings are + lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution + of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device_type = get_device() + device = torch.device(f"{device_type}:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + empty_device_cache() # otherwise we don't see the memory savings (but they probably exist) + + model_sequence = [ + self.text_encoder.text_model, + self.text_encoder.text_projection, + self.unet, + self.vae, + self.vocoder, + self.text_encoder, + ] + + hook = None + for cpu_offloaded_model in model_sequence: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + audio_length_in_s: float | None = None, + num_inference_steps: int = 200, + guidance_scale: float = 2.0, + negative_prompt: str | list[str] | None = None, + num_waveforms_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int | None = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + output_type: str | None = "np", + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`. + audio_length_in_s (`int`, *optional*, defaults to 10.24): + The length of the generated audio sample in seconds. + num_inference_steps (`int`, *optional*, defaults to 200): + The number of denoising steps. More denoising steps usually lead to a higher quality audio at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 2.0): + A higher guidance scale value encourages the model to generate audio that is closely linked to the text + `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in audio generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_waveforms_per_prompt (`int`, *optional*, defaults to 1): + The number of waveforms to generate per prompt. If `num_waveforms_per_prompt > 1`, the text encoding + model is a joint text-audio model ([`~transformers.ClapModel`]), and the tokenizer is a + `[~transformers.ClapProcessor]`, then automatic scoring will be performed between the generated outputs + and the input text. This scoring ranks the generated waveforms based on their cosine similarity to text + input in the joint text-audio embedding space. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or + `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion + model (LDM) output. + + Examples: + + Returns: + [`~pipelines.AudioPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.AudioPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated audio. + """ + # 0. Convert audio input length from seconds to spectrogram height + vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate + + if audio_length_in_s is None: + audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor + + height = int(audio_length_in_s / vocoder_upsample_factor) + + original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate) + if height % self.vae_scale_factor != 0: + height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor + logger.info( + f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} " + f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the " + f"denoising process." + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + audio_length_in_s, + vocoder_upsample_factor, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_waveforms_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_waveforms_per_prompt, + num_channels_latents, + height, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=None, + class_labels=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + self.maybe_free_model_hooks() + + # 8. Post-processing + if not output_type == "latent": + latents = 1 / self.vae.config.scaling_factor * latents + mel_spectrogram = self.vae.decode(latents).sample + else: + return AudioPipelineOutput(audios=latents) + + audio = self.mel_spectrogram_to_waveform(mel_spectrogram) + + audio = audio[:, :original_waveform_length] + + # 9. Automatic scoring + if num_waveforms_per_prompt > 1 and prompt is not None: + audio = self.score_waveforms( + text=prompt, + audio=audio, + num_waveforms_per_prompt=num_waveforms_per_prompt, + device=device, + dtype=prompt_embeds.dtype, + ) + + if output_type == "np": + audio = audio.numpy() + + if not return_dict: + return (audio,) + + return AudioPipelineOutput(audios=audio) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/omnigen/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/omnigen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..557e7c08dc225f14766701b4953e4ff3c9dddb63 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/omnigen/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_omnigen"] = ["OmniGenPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_omnigen import OmniGenPipeline + + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/omnigen/pipeline_omnigen.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/omnigen/pipeline_omnigen.py new file mode 100644 index 0000000000000000000000000000000000000000..bdf4e30c66199cd715c7667f4e659952aaf55b87 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -0,0 +1,542 @@ +# Copyright 2025 OmniGen team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import torch +from transformers import LlamaTokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import OmniGenTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import deprecate, is_torch_xla_available, is_torchvision_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torchvision_available(): + from .processor_omnigen import OmniGenMultiModalProcessor + +if is_torch_xla_available(): + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import OmniGenPipeline + + >>> pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0] + >>> image.save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class OmniGenPipeline( + DiffusionPipeline, +): + r""" + The OmniGen pipeline for multimodal-to-image generation. + + Reference: https://huggingface.co/papers/2409.11340 + + Args: + transformer ([`OmniGenTransformer2DModel`]): + Autoregressive Transformer architecture for OmniGen. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + tokenizer (`LlamaTokenizer`): + Text tokenizer of class. + [LlamaTokenizer](https://huggingface.co/docs/transformers/main/model_doc/llama#transformers.LlamaTokenizer). + """ + + model_cpu_offload_seq = "transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents"] + + def __init__( + self, + transformer: OmniGenTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + tokenizer: LlamaTokenizer, + ): + super().__init__() + + self.register_modules( + vae=vae, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) is not None else 8 + ) + # OmniGen latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + self.multimodal_processor = OmniGenMultiModalProcessor(tokenizer, max_image_size=1024) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 120000 + ) + self.default_sample_size = 128 + + def encode_input_images( + self, + input_pixel_values: list[torch.Tensor], + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + """ + get the continue embedding of input images by VAE + + Args: + input_pixel_values: normalized pixel of input images + device: + Returns: torch.Tensor + """ + device = device or self._execution_device + dtype = dtype or self.vae.dtype + + input_img_latents = [] + for img in input_pixel_values: + img = self.vae.encode(img.to(device, dtype)).latent_dist.sample().mul_(self.vae.config.scaling_factor) + input_img_latents.append(img) + return input_img_latents + + def check_inputs( + self, + prompt, + input_images, + height, + width, + use_input_image_size_as_output, + callback_on_step_end_tensor_inputs=None, + ): + if input_images is not None: + if len(input_images) != len(prompt): + raise ValueError( + f"The number of prompts: {len(prompt)} does not match the number of input images: {len(input_images)}." + ) + for i in range(len(input_images)): + if input_images[i] is not None: + if not all(f"<|image_{k + 1}|>" in prompt[i] for k in range(len(input_images[i]))): + raise ValueError( + f"prompt `{prompt[i]}` doesn't have enough placeholders for the input images `{input_images[i]}`" + ) + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if use_input_image_size_as_output: + if input_images is None or input_images[0] is None: + raise ValueError( + "`use_input_image_size_as_output` is set to True, but no input image was found. If you are performing a text-to-image task, please set it to False." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str], + input_images: PipelineImageInput | list[PipelineImageInput] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + max_input_image_size: int = 1024, + timesteps: list[int] = None, + guidance_scale: float = 2.5, + img_guidance_scale: float = 1.6, + use_input_image_size_as_output: bool = False, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If the input includes images, need to add + placeholders `<|image_i|>` in the prompt to indicate the position of the i-th images. + input_images (`PipelineImageInput` or `list[PipelineImageInput]`, *optional*): + The list of input images. We will replace the "<|image_i|>" in prompt with the i-th image in list. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + max_input_image_size (`int`, *optional*, defaults to 1024): + the maximum size of input image, which will be used to crop the input image to the maximum size + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 2.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + img_guidance_scale (`float`, *optional*, defaults to 1.6): + Defined as equation 3 in [Instrucpix2pix](https://huggingface.co/papers/2211.09800). + use_input_image_size_as_output (bool, defaults to False): + whether to use the input image size as the output image size, which can be used for single-image input, + e.g., image editing task + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + num_cfg = 2 if input_images is not None else 1 + use_img_cfg = True if input_images is not None else False + if isinstance(prompt, str): + prompt = [prompt] + input_images = [input_images] + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + input_images, + height, + width, + use_input_image_size_as_output, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Define call parameters + batch_size = len(prompt) + device = self._execution_device + + # 3. process multi-modal instructions + if max_input_image_size != self.multimodal_processor.max_image_size: + self.multimodal_processor.reset_max_image_size(max_image_size=max_input_image_size) + processed_data = self.multimodal_processor( + prompt, + input_images, + height=height, + width=width, + use_img_cfg=use_img_cfg, + use_input_image_size_as_output=use_input_image_size_as_output, + num_images_per_prompt=num_images_per_prompt, + ) + processed_data["input_ids"] = processed_data["input_ids"].to(device) + processed_data["attention_mask"] = processed_data["attention_mask"].to(device) + processed_data["position_ids"] = processed_data["position_ids"].to(device) + + # 4. Encode input images + input_img_latents = self.encode_input_images(processed_data["input_pixel_values"], device=device) + + # 5. Prepare timesteps + sigmas = np.linspace(1, 0, num_inference_steps + 1)[:num_inference_steps] + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas=sigmas + ) + self._num_timesteps = len(timesteps) + + # 6. Prepare latents + transformer_dtype = self.transformer.dtype + if use_input_image_size_as_output: + height, width = processed_data["input_pixel_values"][0].shape[-2:] + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # 8. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (num_cfg + 1)) + latent_model_input = latent_model_input.to(transformer_dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + input_ids=processed_data["input_ids"], + input_img_latents=input_img_latents, + input_image_sizes=processed_data["input_image_sizes"], + attention_mask=processed_data["attention_mask"], + position_ids=processed_data["position_ids"], + return_dict=False, + )[0] + + if num_cfg == 2: + cond, uncond, img_cond = torch.split(noise_pred, len(noise_pred) // 3, dim=0) + noise_pred = uncond + img_guidance_scale * (img_cond - uncond) + guidance_scale * (cond - img_cond) + else: + cond, uncond = torch.split(noise_pred, len(noise_pred) // 2, dim=0) + noise_pred = uncond + guidance_scale * (cond - uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + + progress_bar.update() + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents = latents / self.vae.config.scaling_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + else: + image = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/omnigen/processor_omnigen.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/omnigen/processor_omnigen.py new file mode 100644 index 0000000000000000000000000000000000000000..8a7f53a771368d850519ee3d728d19c3b218a4ce --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/omnigen/processor_omnigen.py @@ -0,0 +1,331 @@ +# Copyright 2025 OmniGen team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +import numpy as np +import torch +from PIL import Image + +from ...utils import is_torchvision_available + + +if is_torchvision_available(): + from torchvision import transforms + + +def crop_image(pil_image, max_image_size): + """ + Crop the image so that its height and width does not exceed `max_image_size`, while ensuring both the height and + width are multiples of 16. + """ + while min(*pil_image.size) >= 2 * max_image_size: + pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX) + + if max(*pil_image.size) > max_image_size: + scale = max_image_size / max(*pil_image.size) + pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC) + + if min(*pil_image.size) < 16: + scale = 16 / min(*pil_image.size) + pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC) + + arr = np.array(pil_image) + crop_y1 = (arr.shape[0] % 16) // 2 + crop_y2 = arr.shape[0] % 16 - crop_y1 + + crop_x1 = (arr.shape[1] % 16) // 2 + crop_x2 = arr.shape[1] % 16 - crop_x1 + + arr = arr[crop_y1 : arr.shape[0] - crop_y2, crop_x1 : arr.shape[1] - crop_x2] + return Image.fromarray(arr) + + +class OmniGenMultiModalProcessor: + def __init__(self, text_tokenizer, max_image_size: int = 1024): + self.text_tokenizer = text_tokenizer + self.max_image_size = max_image_size + + self.image_transform = transforms.Compose( + [ + transforms.Lambda(lambda pil_image: crop_image(pil_image, max_image_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + + self.collator = OmniGenCollator() + + def reset_max_image_size(self, max_image_size): + self.max_image_size = max_image_size + self.image_transform = transforms.Compose( + [ + transforms.Lambda(lambda pil_image: crop_image(pil_image, max_image_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + + def process_image(self, image): + if isinstance(image, str): + image = Image.open(image).convert("RGB") + return self.image_transform(image) + + def process_multi_modal_prompt(self, text, input_images): + text = self.add_prefix_instruction(text) + if input_images is None or len(input_images) == 0: + model_inputs = self.text_tokenizer(text) + return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None} + + pattern = r"<\|image_\d+\|>" + prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)] + + for i in range(1, len(prompt_chunks)): + if prompt_chunks[i][0] == 1: + prompt_chunks[i] = prompt_chunks[i][1:] + + image_tags = re.findall(pattern, text) + image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags] + + unique_image_ids = sorted(set(image_ids)) + assert unique_image_ids == list(range(1, len(unique_image_ids) + 1)), ( + f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}" + ) + # total images must be the same as the number of image tags + assert len(unique_image_ids) == len(input_images), ( + f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images" + ) + + input_images = [input_images[x - 1] for x in image_ids] + + all_input_ids = [] + img_inx = [] + for i in range(len(prompt_chunks)): + all_input_ids.extend(prompt_chunks[i]) + if i != len(prompt_chunks) - 1: + start_inx = len(all_input_ids) + size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16 + img_inx.append([start_inx, start_inx + size]) + all_input_ids.extend([0] * size) + + return {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx} + + def add_prefix_instruction(self, prompt): + user_prompt = "<|user|>\n" + generation_prompt = "Generate an image according to the following instructions\n" + assistant_prompt = "<|assistant|>\n<|diffusion|>" + prompt_suffix = "<|end|>\n" + prompt = f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}" + return prompt + + def __call__( + self, + instructions: list[str], + input_images: list[list[str]] = None, + height: int = 1024, + width: int = 1024, + negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.", + use_img_cfg: bool = True, + separate_cfg_input: bool = False, + use_input_image_size_as_output: bool = False, + num_images_per_prompt: int = 1, + ) -> dict: + if isinstance(instructions, str): + instructions = [instructions] + input_images = [input_images] + + input_data = [] + for i in range(len(instructions)): + cur_instruction = instructions[i] + cur_input_images = None if input_images is None else input_images[i] + if cur_input_images is not None and len(cur_input_images) > 0: + cur_input_images = [self.process_image(x) for x in cur_input_images] + else: + cur_input_images = None + assert "<|image_1|>" not in cur_instruction + + mllm_input = self.process_multi_modal_prompt(cur_instruction, cur_input_images) + + neg_mllm_input, img_cfg_mllm_input = None, None + neg_mllm_input = self.process_multi_modal_prompt(negative_prompt, None) + if use_img_cfg: + if cur_input_images is not None and len(cur_input_images) >= 1: + img_cfg_prompt = [f"<|image_{i + 1}|>" for i in range(len(cur_input_images))] + img_cfg_mllm_input = self.process_multi_modal_prompt(" ".join(img_cfg_prompt), cur_input_images) + else: + img_cfg_mllm_input = neg_mllm_input + + for _ in range(num_images_per_prompt): + if use_input_image_size_as_output: + input_data.append( + ( + mllm_input, + neg_mllm_input, + img_cfg_mllm_input, + [mllm_input["pixel_values"][0].size(-2), mllm_input["pixel_values"][0].size(-1)], + ) + ) + else: + input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width])) + + return self.collator(input_data) + + +class OmniGenCollator: + def __init__(self, pad_token_id=2, hidden_size=3072): + self.pad_token_id = pad_token_id + self.hidden_size = hidden_size + + def create_position(self, attention_mask, num_tokens_for_output_images): + position_ids = [] + text_length = attention_mask.size(-1) + img_length = max(num_tokens_for_output_images) + for mask in attention_mask: + temp_l = torch.sum(mask) + temp_position = [0] * (text_length - temp_l) + list( + range(temp_l + img_length + 1) + ) # we add a time embedding into the sequence, so add one more token + position_ids.append(temp_position) + return torch.LongTensor(position_ids) + + def create_mask(self, attention_mask, num_tokens_for_output_images): + """ + OmniGen applies causal attention to each element in the sequence, but applies bidirectional attention within + each image sequence References: [OmniGen](https://huggingface.co/papers/2409.11340) + """ + extended_mask = [] + padding_images = [] + text_length = attention_mask.size(-1) + img_length = max(num_tokens_for_output_images) + seq_len = text_length + img_length + 1 # we add a time embedding into the sequence, so add one more token + inx = 0 + for mask in attention_mask: + temp_l = torch.sum(mask) + pad_l = text_length - temp_l + + temp_mask = torch.tril(torch.ones(size=(temp_l + 1, temp_l + 1))) + + image_mask = torch.zeros(size=(temp_l + 1, img_length)) + temp_mask = torch.cat([temp_mask, image_mask], dim=-1) + + image_mask = torch.ones(size=(img_length, temp_l + img_length + 1)) + temp_mask = torch.cat([temp_mask, image_mask], dim=0) + + if pad_l > 0: + pad_mask = torch.zeros(size=(temp_l + 1 + img_length, pad_l)) + temp_mask = torch.cat([pad_mask, temp_mask], dim=-1) + + pad_mask = torch.ones(size=(pad_l, seq_len)) + temp_mask = torch.cat([pad_mask, temp_mask], dim=0) + + true_img_length = num_tokens_for_output_images[inx] + pad_img_length = img_length - true_img_length + if pad_img_length > 0: + temp_mask[:, -pad_img_length:] = 0 + temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size)) + else: + temp_padding_imgs = None + + extended_mask.append(temp_mask.unsqueeze(0)) + padding_images.append(temp_padding_imgs) + inx += 1 + return torch.cat(extended_mask, dim=0), padding_images + + def adjust_attention_for_input_images(self, attention_mask, image_sizes): + for b_inx in image_sizes.keys(): + for start_inx, end_inx in image_sizes[b_inx]: + attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1 + + return attention_mask + + def pad_input_ids(self, input_ids, image_sizes): + max_l = max([len(x) for x in input_ids]) + padded_ids = [] + attention_mask = [] + + for i in range(len(input_ids)): + temp_ids = input_ids[i] + temp_l = len(temp_ids) + pad_l = max_l - temp_l + if pad_l == 0: + attention_mask.append([1] * max_l) + padded_ids.append(temp_ids) + else: + attention_mask.append([0] * pad_l + [1] * temp_l) + padded_ids.append([self.pad_token_id] * pad_l + temp_ids) + + if i in image_sizes: + new_inx = [] + for old_inx in image_sizes[i]: + new_inx.append([x + pad_l for x in old_inx]) + image_sizes[i] = new_inx + + return torch.LongTensor(padded_ids), torch.LongTensor(attention_mask), image_sizes + + def process_mllm_input(self, mllm_inputs, target_img_size): + num_tokens_for_output_images = [] + for img_size in target_img_size: + num_tokens_for_output_images.append(img_size[0] * img_size[1] // 16 // 16) + + pixel_values, image_sizes = [], {} + b_inx = 0 + for x in mllm_inputs: + if x["pixel_values"] is not None: + pixel_values.extend(x["pixel_values"]) + for size in x["image_sizes"]: + if b_inx not in image_sizes: + image_sizes[b_inx] = [size] + else: + image_sizes[b_inx].append(size) + b_inx += 1 + pixel_values = [x.unsqueeze(0) for x in pixel_values] + + input_ids = [x["input_ids"] for x in mllm_inputs] + padded_input_ids, attention_mask, image_sizes = self.pad_input_ids(input_ids, image_sizes) + position_ids = self.create_position(attention_mask, num_tokens_for_output_images) + attention_mask, padding_images = self.create_mask(attention_mask, num_tokens_for_output_images) + attention_mask = self.adjust_attention_for_input_images(attention_mask, image_sizes) + + return padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes + + def __call__(self, features): + mllm_inputs = [f[0] for f in features] + cfg_mllm_inputs = [f[1] for f in features] + img_cfg_mllm_input = [f[2] for f in features] + target_img_size = [f[3] for f in features] + + if img_cfg_mllm_input[0] is not None: + mllm_inputs = mllm_inputs + cfg_mllm_inputs + img_cfg_mllm_input + target_img_size = target_img_size + target_img_size + target_img_size + else: + mllm_inputs = mllm_inputs + cfg_mllm_inputs + target_img_size = target_img_size + target_img_size + + ( + all_padded_input_ids, + all_position_ids, + all_attention_mask, + all_padding_images, + all_pixel_values, + all_image_sizes, + ) = self.process_mllm_input(mllm_inputs, target_img_size) + + data = { + "input_ids": all_padded_input_ids, + "attention_mask": all_attention_mask, + "position_ids": all_position_ids, + "input_pixel_values": all_pixel_values, + "input_image_sizes": all_image_sizes, + } + return data diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ovis_image/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ovis_image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..275061b1f6260c9cccc38d0116e6bf582e3d7e19 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ovis_image/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa: F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_output"] = ["OvisImagePipelineOutput"] + _import_structure["pipeline_ovis_image"] = ["OvisImagePipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_output import OvisImagePipelineOutput + from .pipeline_ovis_image import OvisImagePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ovis_image/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ovis_image/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e07e02a5f5c94d476ca5498df8d33eee6bb14764 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ovis_image/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ovis_image/__pycache__/pipeline_output.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ovis_image/__pycache__/pipeline_output.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1503c69ca4e2108fa98f39de3b4cfae8b237078 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ovis_image/__pycache__/pipeline_output.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ovis_image/__pycache__/pipeline_ovis_image.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ovis_image/__pycache__/pipeline_ovis_image.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f9d20ceb81a744804f25549403f3593db33f951 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ovis_image/__pycache__/pipeline_ovis_image.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ovis_image/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ovis_image/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..4daecf3b2197565c973c1efb9d887a27677fbc9b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ovis_image/pipeline_output.py @@ -0,0 +1,34 @@ +# Copyright 2025 Alibaba Ovis-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from diffusers.utils import BaseOutput + + +@dataclass +class OvisImagePipelineOutput(BaseOutput): + """ + Output class for Ovis-Image pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image, np.ndarray] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ovis_image/pipeline_ovis_image.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ovis_image/pipeline_ovis_image.py new file mode 100644 index 0000000000000000000000000000000000000000..c8ff8227f27e7e15f2664e74413fd60bbdf13102 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/ovis_image/pipeline_ovis_image.py @@ -0,0 +1,668 @@ +# Copyright 2025 Alibaba Ovis-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import Qwen2TokenizerFast, Qwen3Model + +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, OvisImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import OvisImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import OvisImagePipeline + + >>> pipe = OvisImagePipeline.from_pretrained("AIDC-AI/Ovis-Image-7B", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = 'A creative 3D artistic render where the text "OVIS-IMAGE" is written in a bold, expressive handwritten brush style using thick, wet oil paint. The paint is a mix of vibrant rainbow colors (red, blue, yellow) swirling together like toothpaste or impasto art. You can see the ridges of the brush bristles and the glossy, wet texture of the paint. The background is a clean artist\'s canvas. Dynamic lighting creates soft shadows behind the floating paint strokes. Colorful, expressive, tactile texture, 4k detail.' + >>> image = pipe(prompt, negative_prompt="", num_inference_steps=50, guidance_scale=5.0).images[0] + >>> image.save("ovis_image.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class OvisImagePipeline( + DiffusionPipeline, +): + r""" + The Ovis-Image pipeline for text-to-image generation. + + Reference: https://github.com/AIDC-AI/Ovis-Image + + Args: + transformer ([`OvisImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen3Model`]): + Text encoder of class + [Qwen3Model](https://huggingface.co/docs/transformers/en/model_doc/qwen3#transformers.Qwen3Model). + tokenizer (`Qwen2TokenizerFast`): + Tokenizer of class + [Qwen2TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: Qwen3Model, + tokenizer: Qwen2TokenizerFast, + transformer: OvisImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Ovis-Image latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.system_prompt = "Describe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background: " + self.user_prompt_begin_id = 28 + self.tokenizer_max_length = 256 + self.user_prompt_begin_id + self.default_sample_size = 128 + + def _get_messages( + self, + prompt: str | list[str] = None, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + messages = [] + for each_prompt in prompt: + message = [ + { + "role": "user", + "content": self.system_prompt + each_prompt, + } + ] + message = self.tokenizer.apply_chat_template( + message, tokenize=False, add_generation_prompt=True, enable_thinking=False + ) + messages.append(message) + return messages + + def _get_ovis_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + messages = self._get_messages(prompt) + batch_size = len(messages) + + tokens = self.tokenizer( + messages, + padding="max_length", + truncation=True, + max_length=self.tokenizer_max_length, + return_tensors="pt", + add_special_tokens=False, + ) + input_ids = tokens.input_ids.to(device) + attention_mask = tokens.attention_mask.to(device) + outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = outputs.last_hidden_state + prompt_embeds = prompt_embeds * attention_mask[..., None] + prompt_embeds = prompt_embeds[:, self.user_prompt_begin_id :, :] + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.FloatTensor | None = None, + ): + r""" + + Args: + prompt (`str`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + device = device or self._execution_device + + if prompt_embeds is None: + prompt_embeds = self._get_ovis_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3) + text_ids[..., 1] = text_ids[..., 1] + torch.arange(prompt_embeds.shape[1])[None, :] + text_ids[..., 2] = text_ids[..., 2] + torch.arange(prompt_embeds.shape[1])[None, :] + text_ids = text_ids.to(device=device, dtype=dtype) + return prompt_embeds, text_ids + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if max_sequence_length is not None and max_sequence_length > 256: + raise ValueError(f"`max_sequence_length` cannot be greater than 256 but is {max_sequence_length}") + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = "", + guidance_scale: float = 5.0, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 256, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + not greater than `1`). + guidance_scale (`float`, *optional*, defaults to 1.0): + True classifier-free guidance (guidance scale) is enabled when `guidance_scale` > 1 and + `negative_prompt` is provided. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ovis_image.OvisImagePipelineOutput`] or `tuple`: + [`~pipelines.ovis_image.OvisImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + do_classifier_free_guidance = guidance_scale > 1 + ( + prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + if do_classifier_free_guidance: + ( + negative_prompt_embeds, + negative_text_ids, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + # 6. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return OvisImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..176efe3adef6a792ba2e7fef194e2a0a9475fbfd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/__init__.py @@ -0,0 +1,80 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"] + _import_structure["pipeline_pag_controlnet_sd_inpaint"] = ["StableDiffusionControlNetPAGInpaintPipeline"] + _import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"] + _import_structure["pipeline_pag_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetPAGImg2ImgPipeline"] + _import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"] + _import_structure["pipeline_pag_kolors"] = ["KolorsPAGPipeline"] + _import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"] + _import_structure["pipeline_pag_sana"] = ["SanaPAGPipeline"] + _import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"] + _import_structure["pipeline_pag_sd_3"] = ["StableDiffusion3PAGPipeline"] + _import_structure["pipeline_pag_sd_3_img2img"] = ["StableDiffusion3PAGImg2ImgPipeline"] + _import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"] + _import_structure["pipeline_pag_sd_img2img"] = ["StableDiffusionPAGImg2ImgPipeline"] + _import_structure["pipeline_pag_sd_inpaint"] = ["StableDiffusionPAGInpaintPipeline"] + + _import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"] + _import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"] + _import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline + from .pipeline_pag_controlnet_sd_inpaint import StableDiffusionControlNetPAGInpaintPipeline + from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline + from .pipeline_pag_controlnet_sd_xl_img2img import StableDiffusionXLControlNetPAGImg2ImgPipeline + from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline + from .pipeline_pag_kolors import KolorsPAGPipeline + from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline + from .pipeline_pag_sana import SanaPAGPipeline + from .pipeline_pag_sd import StableDiffusionPAGPipeline + from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline + from .pipeline_pag_sd_3_img2img import StableDiffusion3PAGImg2ImgPipeline + from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline + from .pipeline_pag_sd_img2img import StableDiffusionPAGImg2ImgPipeline + from .pipeline_pag_sd_inpaint import StableDiffusionPAGInpaintPipeline + from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline + from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline + from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/__pycache__/pipeline_pag_controlnet_sd.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/__pycache__/pipeline_pag_controlnet_sd.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5655e72d6b4644bd62698dcaa3758314255bb46 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/__pycache__/pipeline_pag_controlnet_sd.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/__pycache__/pipeline_pag_controlnet_sd_xl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/__pycache__/pipeline_pag_controlnet_sd_xl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0d533213cbb9a549378d2adb8e8c3887f0d8a66 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/__pycache__/pipeline_pag_controlnet_sd_xl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/__pycache__/pipeline_pag_hunyuandit.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/__pycache__/pipeline_pag_hunyuandit.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5daa96c2cdaed067b1b3b750c8b4f182e2e4d06a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/__pycache__/pipeline_pag_hunyuandit.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/__pycache__/pipeline_pag_kolors.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/__pycache__/pipeline_pag_kolors.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c645328d7cb538fe3ef1388296328779692c45d5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/__pycache__/pipeline_pag_kolors.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/__pycache__/pipeline_pag_sana.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/__pycache__/pipeline_pag_sana.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..773bb450e37eb0e1a67a397569c56f4780a6d962 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/__pycache__/pipeline_pag_sana.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/__pycache__/pipeline_pag_sd_3.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/__pycache__/pipeline_pag_sd_3.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c780e7c4bb09664435a9e847f21cf0ea475efc4d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/__pycache__/pipeline_pag_sd_3.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/__pycache__/pipeline_pag_sd_inpaint.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/__pycache__/pipeline_pag_sd_inpaint.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7289786c908c5f3c29ab6f64482b1d62e80078e9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/__pycache__/pipeline_pag_sd_inpaint.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pag_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pag_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..41395ece742155a9f0a9003cf53e398ca8661c23 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pag_utils.py @@ -0,0 +1,242 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +import torch +import torch.nn as nn + +from ...models.attention_processor import ( + Attention, + AttentionProcessor, + PAGCFGIdentitySelfAttnProcessor2_0, + PAGIdentitySelfAttnProcessor2_0, +) +from ...utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class PAGMixin: + r"""Mixin class for [Pertubed Attention Guidance](https://huggingface.co/papers/2403.17377v1).""" + + def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance): + r""" + Set the attention processor for the PAG layers. + """ + pag_attn_processors = self._pag_attn_processors + if pag_attn_processors is None: + raise ValueError( + "No PAG attention processors have been set. Set the attention processors by calling `set_pag_applied_layers` and passing the relevant parameters." + ) + + pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1] + + if hasattr(self, "unet"): + model: nn.Module = self.unet + else: + model: nn.Module = self.transformer + + def is_self_attn(module: nn.Module) -> bool: + r""" + Check if the module is self-attention module based on its name. + """ + return isinstance(module, Attention) and not module.is_cross_attention + + def is_fake_integral_match(layer_id, name): + layer_id = layer_id.split(".")[-1] + name = name.split(".")[-1] + return layer_id.isnumeric() and name.isnumeric() and layer_id == name + + for layer_id in pag_applied_layers: + # for each PAG layer input, we find corresponding self-attention layers in the unet model + target_modules = [] + + for name, module in model.named_modules(): + # Identify the following simple cases: + # (1) Self Attention layer existing + # (2) Whether the module name matches pag layer id even partially + # (3) Make sure it's not a fake integral match if the layer_id ends with a number + # For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1" + if ( + is_self_attn(module) + and re.search(layer_id, name) is not None + and not is_fake_integral_match(layer_id, name) + ): + logger.debug(f"Applying PAG to layer: {name}") + target_modules.append(module) + + if len(target_modules) == 0: + raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}") + + for module in target_modules: + module.processor = pag_attn_proc + + def _get_pag_scale(self, t): + r""" + Get the scale factor for the perturbed attention guidance at timestep `t`. + """ + + if self.do_pag_adaptive_scaling: + signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t) + if signal_scale < 0: + signal_scale = 0 + return signal_scale + else: + return self.pag_scale + + def _apply_perturbed_attention_guidance( + self, noise_pred, do_classifier_free_guidance, guidance_scale, t, return_pred_text=False + ): + r""" + Apply perturbed attention guidance to the noise prediction. + + Args: + noise_pred (torch.Tensor): The noise prediction tensor. + do_classifier_free_guidance (bool): Whether to apply classifier-free guidance. + guidance_scale (float): The scale factor for the guidance term. + t (int): The current time step. + return_pred_text (bool): Whether to return the text noise prediction. + + Returns: + torch.Tensor | tuple[torch.Tensor, torch.Tensor]: The updated noise prediction tensor after applying + perturbed attention guidance and the text noise prediction. + """ + pag_scale = self._get_pag_scale(t) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3) + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_uncond) + + pag_scale * (noise_pred_text - noise_pred_perturb) + ) + else: + noise_pred_text, noise_pred_perturb = noise_pred.chunk(2) + noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb) + if return_pred_text: + return noise_pred, noise_pred_text + return noise_pred + + def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance): + """ + Prepares the perturbed attention guidance for the PAG model. + + Args: + cond (torch.Tensor): The conditional input tensor. + uncond (torch.Tensor): The unconditional input tensor. + do_classifier_free_guidance (bool): Flag indicating whether to perform classifier-free guidance. + + Returns: + torch.Tensor: The prepared perturbed attention guidance tensor. + """ + + cond = torch.cat([cond] * 2, dim=0) + + if do_classifier_free_guidance: + cond = torch.cat([uncond, cond], dim=0) + return cond + + def set_pag_applied_layers( + self, + pag_applied_layers: str | list[str], + pag_attn_processors: tuple[AttentionProcessor, AttentionProcessor] = ( + PAGCFGIdentitySelfAttnProcessor2_0(), + PAGIdentitySelfAttnProcessor2_0(), + ), + ): + r""" + Set the self-attention layers to apply PAG. Raise ValueError if the input is invalid. + + Args: + pag_applied_layers (`str` or `list[str]`): + One or more strings identifying the layer names, or a simple regex for matching multiple layers, where + PAG is to be applied. A few ways of expected usage are as follows: + - Single layers specified as - "blocks.{layer_index}" + - Multiple layers as a list - ["blocks.{layers_index_1}", "blocks.{layer_index_2}", ...] + - Multiple layers as a block name - "mid" + - Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})" + pag_attn_processors: + (`tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(), + PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention + processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second + attention processor is for PAG with CFG disabled (unconditional only). + """ + + if not hasattr(self, "_pag_attn_processors"): + self._pag_attn_processors = None + + if not isinstance(pag_applied_layers, list): + pag_applied_layers = [pag_applied_layers] + if pag_attn_processors is not None: + if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2: + raise ValueError("Expected a tuple of two attention processors") + + for i in range(len(pag_applied_layers)): + if not isinstance(pag_applied_layers[i], str): + raise ValueError( + f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}" + ) + + self.pag_applied_layers = pag_applied_layers + self._pag_attn_processors = pag_attn_processors + + @property + def pag_scale(self) -> float: + r"""Get the scale factor for the perturbed attention guidance.""" + return self._pag_scale + + @property + def pag_adaptive_scale(self) -> float: + r"""Get the adaptive scale factor for the perturbed attention guidance.""" + return self._pag_adaptive_scale + + @property + def do_pag_adaptive_scaling(self) -> bool: + r"""Check if the adaptive scaling is enabled for the perturbed attention guidance.""" + return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0 + + @property + def do_perturbed_attention_guidance(self) -> bool: + r"""Check if the perturbed attention guidance is enabled.""" + return self._pag_scale > 0 and len(self.pag_applied_layers) > 0 + + @property + def pag_attn_processors(self) -> dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of PAG attention processors: A dictionary contains all PAG attention processors used in the model + with the key as the name of the layer. + """ + + if self._pag_attn_processors is None: + return {} + + valid_attn_processors = {x.__class__ for x in self._pag_attn_processors} + + processors = {} + # We could have iterated through the self.components.items() and checked if a component is + # `ModelMixin` subclassed but that can include a VAE too. + if hasattr(self, "unet"): + denoiser_module = self.unet + elif hasattr(self, "transformer"): + denoiser_module = self.transformer + else: + raise ValueError("No denoiser module found.") + + for name, proc in denoiser_module.attn_processors.items(): + if proc.__class__ in valid_attn_processors: + processors[name] = proc + + return processors diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py new file mode 100644 index 0000000000000000000000000000000000000000..807c42d21bb4f3bc858bcf25bb4a9d1528e888b3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py @@ -0,0 +1,1348 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import AutoPipelineForText2Image, ControlNetModel, UniPCMultistepScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... ) + >>> image = np.array(image) + + >>> # get canny image + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # load control net and stable diffusion v1-5 + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-v1-5", + ... controlnet=controlnet, + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... ) + + >>> # speed up diffusion process with faster scheduler and memory optimization + >>> # remove following line if xformers is not installed + >>> pipe.enable_xformers_memory_efficient_attention() + + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> generator = torch.manual_seed(0) + >>> image = pipe( + ... "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting", + ... guidance_scale=7.5, + ... generator=generator, + ... image=canny_image, + ... pag_scale=10, + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionControlNetPAGPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, + PAGMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `list[ControlNetModel]`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetModel | list[ControlNetModel] | tuple[ControlNetModel] | MultiControlNetModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + pag_applied_layers: str | list[str] = "mid", + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + transposed_image = [list(t) for t in zip(*image)] + if len(transposed_image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: if you pass`image` as a list of list, each sublist must have the same length as the number of controlnets, but the sublists in `image` got {len(transposed_image)} images and {len(self.controlnet.nets)} ControlNets." + ) + for image_ in transposed_image: + self.check_image(image_, prompt, prompt_embeds) + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + else: + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError( + "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. " + "The conditioning scale must be fixed across the batch." + ) + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + guess_mode: bool = False, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single + ControlNet, each will be paired with each prompt in the `prompt` list. This also applies to multiple + ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + # Nested lists as ControlNet condition + if isinstance(image[0], list): + # Transpose the nested image list + image = [list(t) for t in zip(*image)] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + added_cond_kwargs = ( + {"image_embeds": ip_adapter_image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + controlnet_prompt_embeds = prompt_embeds + + # 7.2 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + images = image if isinstance(image, list) else [image] + for i, single_image in enumerate(images): + if self.do_classifier_free_guidance: + single_image = single_image.chunk(2)[0] + + if self.do_perturbed_attention_guidance: + single_image = self._prepare_perturbed_attention_guidance( + single_image, single_image, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + single_image = torch.cat([single_image] * 2) + single_image = single_image.to(device) + images[i] = single_image + + image = images if isinstance(image, list) else images[0] + + # 8. Denoising loop + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if ( + torch.cuda.is_available() + and (is_unet_compiled and is_controlnet_compiled) + and is_torch_higher_equal_2_1 + ): + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + control_model_input = latent_model_input + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Inferred ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + empty_device_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..ebc2e882868cc2d2ddc9f83f3711af950921a298 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py @@ -0,0 +1,1552 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/ + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install transformers accelerate + >>> import cv2 + >>> from diffusers import AutoPipelineForInpainting, ControlNetModel, DDIMScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> from PIL import Image + >>> import torch + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png" + ... ) + >>> init_image = init_image.resize((512, 512)) + + >>> generator = torch.Generator(device="cpu").manual_seed(1) + + >>> mask_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" + ... ) + >>> mask_image = mask_image.resize((512, 512)) + + + >>> def make_canny_condition(image): + ... image = np.array(image) + ... image = cv2.Canny(image, 100, 200) + ... image = image[:, :, None] + ... image = np.concatenate([image, image, image], axis=2) + ... image = Image.fromarray(image) + ... return image + + + >>> control_image = make_canny_condition(init_image) + + >>> controlnet = ControlNetModel.from_pretrained( + ... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16 + ... ) + >>> pipe = AutoPipelineForInpainting.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-v1-5", + ... controlnet=controlnet, + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... ) + + >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> image = pipe( + ... "a handsome man with ray-ban sunglasses", + ... num_inference_steps=20, + ... generator=generator, + ... eta=1.0, + ... image=init_image, + ... mask_image=mask_image, + ... control_image=control_image, + ... pag_scale=0.3, + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class StableDiffusionControlNetPAGInpaintPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, + PAGMixin, +): + r""" + Pipeline for image inpainting using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + > [!TIP] > This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting > + ([stable-diffusion-v1-5/stable-diffusion-inpainting](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting)) + as well as > default text-to-image Stable Diffusion checkpoints > + ([stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)). + Default text-to-image > Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned + on those, such as > + [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `list[ControlNetModel]`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetModel | list[ControlNetModel] | tuple[ControlNetModel] | MultiControlNetModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + pag_applied_layers: str | list[str] = "mid", + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + image, + mask_image, + height, + width, + output_type, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + ): + if height is not None and height % 8 != 0 or width is not None and width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_inpaint.StableDiffusionControlNetInpaintPipeline.prepare_control_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + crops_coords, + resize_mode, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 4: + image_latents = image + else: + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + padding_mask_crop: int | None = None, + strength: float = 1.0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 0.5, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, + `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, NumPy array or tensor representing an image batch to be used as the starting point. For both + NumPy array and PyTorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a + list or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a NumPy array or + a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, + `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, NumPy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a NumPy array or PyTorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for PyTorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for NumPy array, it would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, + W, 1)`, or `(H, W)`. + control_image (`torch.Tensor`, `PIL.Image.Image`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, + `list[list[torch.Tensor]]`, or `list[list[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 0.5): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + control_image, + mask_image, + height, + width, + output_type, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + padding_mask_crop, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if padding_mask_crop is not None: + height, width = self.image_processor.get_default_height_width(image, height, width) + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare control image + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + crops_coords=crops_coords, + resize_mode=resize_mode, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + crops_coords=crops_coords, + resize_mode=resize_mode, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + + control_images.append(control_image_) + + control_image = control_images + else: + assert False + + # 4.1 Preprocess mask and image - resizes image and mask w.r.t height and width + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + mask = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + masked_image = init_image * (mask < 0.5) + _, _, height, width = init_image.shape + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + ) + + # 7.1 Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for stable-diffusion-v1-5/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + + # 7.2 Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.3 Prepare embeddings + # ip-adapter + if ip_adapter_image_embeds is not None: + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + added_cond_kwargs = ( + {"image_embeds": ip_adapter_image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # control image + control_images = control_image if isinstance(control_image, list) else [control_image] + for i, single_control_image in enumerate(control_images): + if self.do_classifier_free_guidance: + single_control_image = single_control_image.chunk(2)[0] + + if self.do_perturbed_attention_guidance: + single_control_image = self._prepare_perturbed_attention_guidance( + single_control_image, single_control_image, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + single_control_image = torch.cat([single_control_image] * 2) + single_control_image = single_control_image.to(device) + control_images[i] = single_control_image + + control_image = control_images if isinstance(control_image, list) else control_images[0] + controlnet_prompt_embeds = prompt_embeds + + # 7.4 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 7.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + control_model_input = latent_model_input + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + guess_mode=False, + return_dict=False, + ) + + # concat latents, mask, masked_image_latents in the channel dimension + if num_channels_unet == 9: + first_dim_size = latent_model_input.shape[0] + # Ensure mask and masked_image_latents have the right dimensions + if mask.shape[0] < first_dim_size: + repeat_factor = (first_dim_size + mask.shape[0] - 1) // mask.shape[0] + mask = mask.repeat(repeat_factor, 1, 1, 1)[:first_dim_size] + if masked_image_latents.shape[0] < first_dim_size: + repeat_factor = ( + first_dim_size + masked_image_latents.shape[0] - 1 + ) // masked_image_latents.shape[0] + masked_image_latents = masked_image_latents.repeat(repeat_factor, 1, 1, 1)[:first_dim_size] + # Perform the concatenation + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # Predict noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if num_channels_unet == 4: + init_latents_proper = image_latents + if self.do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + empty_device_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if padding_mask_crop is not None: + image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..b0e2c03faed76357f1d691bcd500c065f81a927d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py @@ -0,0 +1,1621 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from diffusers.utils.import_utils import is_invisible_watermark_available + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from .pag_utils import PAGMixin + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + + +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import AutoPipelineForText2Image, ControlNetModel, AutoencoderKL + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" + >>> negative_prompt = "low quality, bad quality, sketches" + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" + ... ) + + >>> # initialize the models and pipeline + >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization + >>> controlnet = ControlNetModel.from_pretrained( + ... "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16 + ... ) + >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", + ... controlnet=controlnet, + ... vae=vae, + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> # get canny image + >>> image = np.array(image) + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # generate image + >>> image = pipe( + ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image, pag_scale=0.3 + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLControlNetPAGPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, + PAGMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]): + Second frozen text-encoder + ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + tokenizer_2 ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `list[ControlNetModel]`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings should always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to + watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no + watermarker is used. + """ + + # leave controlnet out on purpose because it iterates with unet + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + "image_encoder", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetModel | list[ControlNetModel] | tuple[ControlNetModel] | MultiControlNetModel, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool | None = None, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + pag_applied_layers: str | list[str] = "mid", # ["down.block_2", "up.block_1.attentions_0"], "mid" + ): + super().__init__() + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + image, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + negative_pooled_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + denoising_end: float | None = None, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + original_size: tuple[int, int] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` + and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, pooled text embeddings are generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt + weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input + argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned containing the output images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + None, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + # 3.1 Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 3.2 Encode ip_adapter_image + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.1 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 7.2 Prepare added time ids & embeddings + if isinstance(image, list): + original_size = original_size or image[0].shape[-2:] + else: + original_size = original_size or image.shape[-2:] + target_size = target_size or (height, width) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + images = image if isinstance(image, list) else [image] + for i, single_image in enumerate(images): + if self.do_classifier_free_guidance: + single_image = single_image.chunk(2)[0] + + if self.do_perturbed_attention_guidance: + single_image = self._prepare_perturbed_attention_guidance( + single_image, single_image, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + single_image = torch.cat([single_image] * 2) + single_image = single_image.to(device) + images[i] = single_image + + image = images if isinstance(image, list) else images[0] + + if ip_adapter_image_embeds is not None: + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + add_text_embeds = self._prepare_perturbed_attention_guidance( + add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + add_time_ids = self._prepare_perturbed_attention_guidance( + add_time_ids, negative_add_time_ids, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if ( + torch.cuda.is_available() + and (is_unet_compiled and is_controlnet_compiled) + and is_torch_higher_equal_2_1 + ): + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + control_model_input = latent_model_input + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=False, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..8967e50251b9b89e544c45130c2390aa86a7e649 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py @@ -0,0 +1,1685 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from diffusers.utils.import_utils import is_invisible_watermark_available + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from .pag_utils import PAGMixin + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + + +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # pip install accelerate transformers safetensors diffusers + + >>> import torch + >>> import numpy as np + >>> from PIL import Image + + >>> from transformers import DPTFeatureExtractor, DPTForDepthEstimation + >>> from diffusers import ControlNetModel, StableDiffusionXLControlNetPAGImg2ImgPipeline, AutoencoderKL + >>> from diffusers.utils import load_image + + >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda") + >>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas") + >>> controlnet = ControlNetModel.from_pretrained( + ... "diffusers/controlnet-depth-sdxl-1.0-small", + ... variant="fp16", + ... use_safetensors="True", + ... torch_dtype=torch.float16, + ... ) + >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) + >>> pipe = StableDiffusionXLControlNetPAGImg2ImgPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", + ... controlnet=controlnet, + ... vae=vae, + ... variant="fp16", + ... use_safetensors=True, + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... ) + >>> pipe.enable_model_cpu_offload() + + + >>> def get_depth_map(image): + ... image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda") + ... with torch.no_grad(), torch.autocast("cuda"): + ... depth_map = depth_estimator(image).predicted_depth + + ... depth_map = torch.nn.functional.interpolate( + ... depth_map.unsqueeze(1), + ... size=(1024, 1024), + ... mode="bicubic", + ... align_corners=False, + ... ) + ... depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True) + ... depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True) + ... depth_map = (depth_map - depth_min) / (depth_max - depth_min) + ... image = torch.cat([depth_map] * 3, dim=1) + ... image = image.permute(0, 2, 3, 1).cpu().numpy()[0] + ... image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8)) + ... return image + + + >>> prompt = "A robot, 4k photo" + >>> image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ).resize((1024, 1024)) + >>> controlnet_conditioning_scale = 0.5 # recommended for good generalization + >>> depth_image = get_depth_map(image) + + >>> images = pipe( + ... prompt, + ... image=image, + ... control_image=depth_image, + ... strength=0.99, + ... num_inference_steps=50, + ... controlnet_conditioning_scale=controlnet_conditioning_scale, + ... ).images + >>> images[0].save(f"robot_cat.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class StableDiffusionXLControlNetPAGImg2ImgPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, + IPAdapterMixin, + PAGMixin, +): + r""" + Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `list[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets + as a list, the outputs from each ControlNet are added together to create one combined additional + conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): + Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the + config of `stabilityai/stable-diffusion-xl-refiner-1-0`. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + "image_encoder", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "add_neg_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetModel | list[ControlNetModel] | tuple[ControlNetModel] | MultiControlNetModel, + scheduler: KarrasDiffusionSchedulers, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool | None = None, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + pag_applied_layers: str | list[str] = "mid", # ["mid"], ["down.block_1", "up.block_0.attentions_0"] + ): + super().__init__() + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl_img2img.StableDiffusionXLControlNetImg2ImgPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + image, + strength, + num_inference_steps, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if num_inference_steps is None: + raise ValueError("`num_inference_steps` cannot be None.") + elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents + def prepare_latents( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + + # Offload text encoder if `enable_model_cpu_offload` was enabled + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.text_encoder_2.to("cpu") + empty_device_cache() + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + strength: float = 0.8, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + controlnet_conditioning_scale: float | list[float] = 0.8, + guess_mode: bool = False, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + original_size: tuple[int, int] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The initial image will be used as the starting point for the image generation process. Can also accept + image latents as `image`, if passing latents directly, it will not be encoded again. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also + be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height + and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in + init, images must be passed as a list such that each element of the list can be correctly batched for + input to a single controlnet. + height (`int`, *optional*, defaults to the size of control_image): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to the size of control_image): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the controlnet starts applying. + control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the controlnet stops applying. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple` containing the output images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + control_image, + strength, + num_inference_steps, + None, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + # 3.1 Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + prompt_2, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 3.2 Encode ip_adapter_image + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare image and controlnet_conditioning_image + image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + height, width = control_image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + + control_images.append(control_image_) + + control_image = control_images + height, width = control_image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + True, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 7.2 Prepare added time ids & embeddings + if isinstance(control_image, list): + original_size = original_size or control_image[0].shape[-2:] + else: + original_size = original_size or control_image.shape[-2:] + target_size = target_size or (height, width) + + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + add_text_embeds = pooled_prompt_embeds + + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + control_images = control_image if isinstance(control_image, list) else [control_image] + for i, single_image in enumerate(control_images): + if self.do_classifier_free_guidance: + single_image = single_image.chunk(2)[0] + + if self.do_perturbed_attention_guidance: + single_image = self._prepare_perturbed_attention_guidance( + single_image, single_image, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + single_image = torch.cat([single_image] * 2) + single_image = single_image.to(device) + control_images[i] = single_image + + control_image = control_images if isinstance(control_image, list) else control_images[0] + + if ip_adapter_image_embeds is not None: + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + add_text_embeds = self._prepare_perturbed_attention_guidance( + add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + add_time_ids = self._prepare_perturbed_attention_guidance( + add_time_ids, add_neg_time_ids, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + control_model_input = latent_model_input + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + guess_mode=False, + added_cond_kwargs=controlnet_added_cond_kwargs, + return_dict=False, + ) + + if ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + empty_device_cache() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py new file mode 100644 index 0000000000000000000000000000000000000000..15ac665acd2b0f3ff75886f21695b9af087e513f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py @@ -0,0 +1,964 @@ +# Copyright 2025 HunyuanDiT Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import torch +from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer + +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, HunyuanDiT2DModel +from ...models.attention_processor import PAGCFGHunyuanAttnProcessor2_0, PAGHunyuanAttnProcessor2_0 +from ...models.embeddings import get_2d_rotary_pos_embed +from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from ...schedulers import DDPMScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import AutoPipelineForText2Image + + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... pag_applied_layers=[14], + ... ).to("cuda") + + >>> # prompt = "an astronaut riding a horse" + >>> prompt = "一个宇航员在骑马" + >>> image = pipe(prompt, guidance_scale=4, pag_scale=3).images[0] + ``` +""" + +STANDARD_RATIO = np.array( + [ + 1.0, # 1:1 + 4.0 / 3.0, # 4:3 + 3.0 / 4.0, # 3:4 + 16.0 / 9.0, # 16:9 + 9.0 / 16.0, # 9:16 + ] +) +STANDARD_SHAPE = [ + [(1024, 1024), (1280, 1280)], # 1:1 + [(1024, 768), (1152, 864), (1280, 960)], # 4:3 + [(768, 1024), (864, 1152), (960, 1280)], # 3:4 + [(1280, 768)], # 16:9 + [(768, 1280)], # 9:16 +] +STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE] +SUPPORTED_SHAPE = [ + (1024, 1024), + (1280, 1280), # 1:1 + (1024, 768), + (1152, 864), + (1280, 960), # 4:3 + (768, 1024), + (864, 1152), + (960, 1280), # 3:4 + (1280, 768), # 16:9 + (768, 1280), # 9:16 +] + + +def map_to_standard_shapes(target_width, target_height): + target_ratio = target_width / target_height + closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio)) + closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height)) + width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx] + return width, height + + +def get_resize_crop_region_for_grid(src, tgt_size): + th = tw = tgt_size + h, w = src + + r = h / w + + # resize + if r > 1: + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin): + r""" + Pipeline for English/Chinese-to-image generation using HunyuanDiT and [Perturbed Attention + Guidance](https://huggingface.co/docs/diffusers/en/using-diffusers/pag). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + HunyuanDiT uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by + ourselves) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. We use + `sdxl-vae-fp16-fix`. + text_encoder (`~transformers.BertModel`, `~transformers.CLIPTextModel` | None): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + HunyuanDiT uses a fine-tuned [bilingual CLIP]. + tokenizer (`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer` | None): + A `BertTokenizer` or `CLIPTokenizer` to tokenize text. + transformer ([`HunyuanDiT2DModel`]): + The HunyuanDiT model designed by Tencent Hunyuan. + text_encoder_2 (`T5EncoderModel`): + The mT5 embedder. Specifically, it is 't5-v1_1-xxl'. + tokenizer_2 (`T5Tokenizer`): + The tokenizer for the mT5 embedder. + scheduler ([`DDPMScheduler`]): + A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [ + "safety_checker", + "feature_extractor", + "text_encoder_2", + "tokenizer_2", + "text_encoder", + "tokenizer", + ] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "prompt_embeds_2", + "negative_prompt_embeds_2", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: BertModel, + tokenizer: BertTokenizer, + transformer: HunyuanDiT2DModel, + scheduler: DDPMScheduler, + safety_checker: StableDiffusionSafetyChecker | None = None, + feature_extractor: CLIPImageProcessor | None = None, + requires_safety_checker: bool = True, + text_encoder_2: T5EncoderModel | None = None, + tokenizer_2: T5Tokenizer | None = None, + pag_applied_layers: str | list[str] = "blocks.1", # "blocks.16.attn1", "blocks.16", "16", 16 + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + text_encoder_2=text_encoder_2, + ) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + + self.set_pag_applied_layers( + pag_applied_layers, pag_attn_processors=(PAGCFGHunyuanAttnProcessor2_0(), PAGHunyuanAttnProcessor2_0()) + ) + + # Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + device: torch.device = None, + dtype: torch.dtype = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + max_sequence_length: int | None = None, + text_encoder_index: int = 0, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + dtype (`torch.dtype`): + torch dtype + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + text_encoder_index (`int`, *optional*): + Index of the text encoder to use. `0` for clip and `1` for T5. + """ + if dtype is None: + if self.text_encoder_2 is not None: + dtype = self.text_encoder_2.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + if device is None: + device = self._execution_device + + tokenizers = [self.tokenizer, self.tokenizer_2] + text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = tokenizers[text_encoder_index] + text_encoder = text_encoders[text_encoder_index] + + if max_sequence_length is None: + if text_encoder_index == 0: + max_length = 77 + if text_encoder_index == 1: + max_length = 256 + else: + max_length = max_sequence_length + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = text_encoder( + text_input_ids.to(device), + attention_mask=prompt_attention_mask, + ) + prompt_embeds = prompt_embeds[0] + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_attention_mask = uncond_input.attention_mask.to(device) + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + prompt_embeds_2=None, + negative_prompt_embeds_2=None, + prompt_attention_mask_2=None, + negative_prompt_attention_mask_2=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if prompt_embeds_2 is not None and prompt_attention_mask_2 is None: + raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None: + raise ValueError( + "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`." + ) + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None: + if prompt_embeds_2.shape != negative_prompt_embeds_2.shape: + raise ValueError( + "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but" + f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`" + f" {negative_prompt_embeds_2.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_2: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_2: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + prompt_attention_mask_2: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask_2: torch.Tensor | None = None, + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int, dict], None] + | PipelineCallback + | MultiPipelineCallbacks + | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + guidance_rescale: float = 0.0, + original_size: tuple[int, int] = (1024, 1024), + target_size: tuple[int, int] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + use_resolution_binning: bool = True, + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation with HunyuanDiT. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback_on_step_end (`Callable[[int, int], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A callback function or a list of callback functions to be called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`list[str]`, *optional*): + A list of tensor inputs that should be passed to the callback function. If not defined, all tensor + inputs will be passed. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise + Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). See Section 3.4 + original_size (`tuple[int, int]`, *optional*, defaults to `(1024, 1024)`): + The original size of the image. Used to calculate the time ids. + target_size (`tuple[int, int]`, *optional*): + The target size of the image. Used to calculate the time ids. + crops_coords_top_left (`tuple[int, int]`, *optional*, defaults to `(0, 0)`): + The top left coordinates of the crop. Used to calculate the time ids. + use_resolution_binning (`bool`, *optional*, defaults to `True`): + Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest + standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960, + 768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + height = int((height // 16) * 16) + width = int((width // 16) * 16) + + if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE: + width, height = map_to_standard_shapes(width, height) + height = int(height) + width = int(width) + logger.warning(f"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}") + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=77, + text_encoder_index=0, + ) + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds_2, + negative_prompt_embeds=negative_prompt_embeds_2, + prompt_attention_mask=prompt_attention_mask_2, + negative_prompt_attention_mask=negative_prompt_attention_mask_2, + max_sequence_length=256, + text_encoder_index=1, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create image_rotary_emb, style embedding & time ids + grid_height = height // 8 // self.transformer.config.patch_size + grid_width = width // 8 // self.transformer.config.patch_size + base_size = 512 // 8 // self.transformer.config.patch_size + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) + image_rotary_emb = get_2d_rotary_pos_embed( + self.transformer.inner_dim // self.transformer.num_heads, + grid_crops_coords, + (grid_height, grid_width), + device=device, + output_type="pt", + ) + + style = torch.tensor([0], device=device) + + target_size = target_size or (height, width) + add_time_ids = list(original_size + target_size + crops_coords_top_left) + add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + prompt_attention_mask = self._prepare_perturbed_attention_guidance( + prompt_attention_mask, negative_prompt_attention_mask, self.do_classifier_free_guidance + ) + prompt_embeds_2 = self._prepare_perturbed_attention_guidance( + prompt_embeds_2, negative_prompt_embeds_2, self.do_classifier_free_guidance + ) + prompt_attention_mask_2 = self._prepare_perturbed_attention_guidance( + prompt_attention_mask_2, negative_prompt_attention_mask_2, self.do_classifier_free_guidance + ) + add_time_ids = torch.cat([add_time_ids] * 3, dim=0) + style = torch.cat([style] * 3, dim=0) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) + style = torch.cat([style] * 2, dim=0) + + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + prompt_embeds_2 = prompt_embeds_2.to(device=device) + prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device) + add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat( + batch_size * num_images_per_prompt, 1 + ) + style = style.to(device=device).repeat(batch_size * num_images_per_prompt) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.transformer.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( + dtype=latent_model_input.dtype + ) + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + text_embedding_mask=prompt_attention_mask, + encoder_hidden_states_t5=prompt_embeds_2, + text_embedding_mask_t5=prompt_attention_mask_2, + image_meta_size=add_time_ids, + style=style, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2) + negative_prompt_embeds_2 = callback_outputs.pop( + "negative_prompt_embeds_2", negative_prompt_embeds_2 + ) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # 9. Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.transformer.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_kolors.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_kolors.py new file mode 100644 index 0000000000000000000000000000000000000000..4f138d91d9c6d1529e9627925c9def05a3184f38 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_kolors.py @@ -0,0 +1,1129 @@ +# Copyright 2025 Stability AI, Kwai-Kolors Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import Any, Callable + +import torch +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..kolors.pipeline_output import KolorsPipelineOutput +from ..kolors.text_encoder import ChatGLMModel +from ..kolors.tokenizer import ChatGLMTokenizer +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForText2Image + + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "Kwai-Kolors/Kolors-diffusers", + ... variant="fp16", + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... pag_applied_layers=["down.block_2.attentions_1", "up.block_0.attentions_1"], + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = ( + ... "A photo of a ladybug, macro, zoom, high quality, film, holding a wooden sign with the text 'KOLORS'" + ... ) + >>> image = pipe(prompt, guidance_scale=5.5, pag_scale=1.5).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class KolorsPAGPipeline( + DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin, PAGMixin +): + r""" + Pipeline for text-to-image generation using Kolors. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`ChatGLMModel`]): + Frozen text-encoder. Kolors uses [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b). + tokenizer (`ChatGLMTokenizer`): + Tokenizer of class + [ChatGLMTokenizer](https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"False"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `Kwai-Kolors/Kolors-diffusers`. + pag_applied_layers (`str` or `list[str]``, *optional*, defaults to `"mid"`): + Set the transformer attention layers where to apply the perturbed attention guidance. Can be a string or a + list of strings with "down", "mid", "up", a whole transformer block or specific transformer block attention + layers, e.g.: + ["mid"] ["down", "mid"] ["down", "mid", "up.block_1"] ["down", "mid", "up.block_1.attentions_0", + "up.block_1.attentions_1"] + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = [ + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: ChatGLMModel, + tokenizer: ChatGLMTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = False, + pag_applied_layers: str | list[str] = "mid", + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.kolors.pipeline_kolors.KolorsPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt=None, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 256, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + """ + # from IPython import embed; embed(); exit() + device = device or self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer] + text_encoders = [self.text_encoder] + + if prompt_embeds is None: + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ).to(device) + output = text_encoder( + input_ids=text_inputs["input_ids"], + attention_mask=text_inputs["attention_mask"], + position_ids=text_inputs["position_ids"], + output_hidden_states=True, + ) + + # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] + # clone to have a contiguous tensor + prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone() + # [max_sequence_length, batch, hidden_size] -> [batch, hidden_size] + pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = prompt_embeds_list[0] + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ).to(device) + output = text_encoder( + input_ids=uncond_input["input_ids"], + attention_mask=uncond_input["attention_mask"], + position_ids=uncond_input["position_ids"], + output_hidden_states=True, + ) + + # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] + # clone to have a contiguous tensor + negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone() + # [max_sequence_length, batch, hidden_size] -> [batch, hidden_size] + negative_pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = negative_prompt_embeds_list[0] + + bs_embed = pooled_prompt_embeds.shape[0] + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.kolors.pipeline_kolors.KolorsPipeline.check_inputs + def check_inputs( + self, + prompt, + num_inference_steps, + height, + width, + negative_prompt=None, + prompt_embeds=None, + pooled_prompt_embeds=None, + negative_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + if max_sequence_length is not None and max_sequence_length > 256: + raise ValueError(f"`max_sequence_length` cannot be greater than 256 but is {max_sequence_length}") + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + denoising_end: float | None = None, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + original_size: tuple[int, int] | None = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] | None = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + max_sequence_length: int = 256, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints + that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints + that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.kolors.KolorsPipelineOutput`] instead of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.kolors.KolorsPipelineOutput`] or `tuple`: [`~pipelines.kolors.KolorsPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + num_inference_steps, + height, + width, + negative_prompt, + prompt_embeds, + pooled_prompt_embeds, + negative_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + ) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + add_text_embeds = self._prepare_perturbed_attention_guidance( + add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + add_time_ids = self._prepare_perturbed_attention_guidance( + add_time_ids, negative_add_time_ids, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9. Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return KolorsPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py new file mode 100644 index 0000000000000000000000000000000000000000..2e7e8ab9814f890b6e3e2b8c0306da87c8a8d53a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py @@ -0,0 +1,882 @@ +# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +from typing import Callable + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...image_processor import PixArtImageProcessor +from ...models import AutoencoderKL, PixArtTransformer2DModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + BACKENDS_MAPPING, + deprecate, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ..pixart_alpha.pipeline_pixart_alpha import ( + ASPECT_RATIO_256_BIN, + ASPECT_RATIO_512_BIN, + ASPECT_RATIO_1024_BIN, +) +from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForText2Image + + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", + ... torch_dtype=torch.float16, + ... pag_applied_layers=["blocks.14"], + ... enable_pag=True, + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "A small cactus with a happy face in the Sahara desert" + >>> image = pipe(prompt, pag_scale=4.0, guidance_scale=1.0).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin): + r""" + [PAG pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/pag) for text-to-image generation + using PixArt-Sigma. + """ + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: PixArtTransformer2DModel, + scheduler: KarrasDiffusionSchedulers, + pag_applied_layers: str | list[str] = "blocks.1", # 1st transformer block + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->300 + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" + string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + """ + + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + + if device is None: + device = self._execution_device + + # See Section 3.1. of the paper. + max_length = max_sequence_length + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because T5 can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt) + prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt) + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: int | None = 1, + height: int | None = None, + width: int | None = None, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + clean_caption: bool = True, + use_resolution_binning: bool = True, + max_sequence_length: int = 300, + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ) -> ImagePipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + max_sequence_length (`int` defaults to 300): Maximum sequence length to use with the `prompt`. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + # 1. Check inputs. Raise error if not correct + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + if use_resolution_binning: + if self.transformer.config.sample_size == 256: + aspect_ratio_bin = ASPECT_RATIO_2048_BIN + elif self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + elif self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_256_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, do_classifier_free_guidance + ) + prompt_attention_mask = self._prepare_perturbed_attention_guidance( + prompt_attention_mask, negative_prompt_attention_mask, do_classifier_free_guidance + ) + elif do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + if self.do_perturbed_attention_guidance: + original_attn_proc = self.transformer.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Prepare micro-conditions. + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + is_npu = latent_model_input.device.type == "npu" + if isinstance(current_timestep, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=current_timestep, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, do_classifier_free_guidance, guidance_scale, current_timestep + ) + elif do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.transformer.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sana.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sana.py new file mode 100644 index 0000000000000000000000000000000000000000..71861f36647717c37e7f18c4a1854e4f68cfe6ed --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sana.py @@ -0,0 +1,978 @@ +# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +import warnings +from typing import Callable + +import torch +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PixArtImageProcessor +from ...models import AutoencoderDC, SanaTransformer2DModel +from ...models.attention_processor import PAGCFGSanaLinearAttnProcessor2_0, PAGIdentitySanaLinearAttnProcessor2_0 +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + BACKENDS_MAPPING, + deprecate, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import get_device, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ..pixart_alpha.pipeline_pixart_alpha import ( + ASPECT_RATIO_512_BIN, + ASPECT_RATIO_1024_BIN, +) +from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN +from ..sana.pipeline_sana import ASPECT_RATIO_4096_BIN +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import SanaPAGPipeline + + >>> pipe = SanaPAGPipeline.from_pretrained( + ... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", + ... pag_applied_layers=["transformer_blocks.8"], + ... torch_dtype=torch.float32, + ... ) + >>> pipe.to("cuda") + >>> pipe.text_encoder.to(torch.bfloat16) + >>> pipe.transformer = pipe.transformer.to(torch.bfloat16) + + >>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0] + >>> image[0].save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class SanaPAGPipeline(DiffusionPipeline, PAGMixin): + r""" + Pipeline for text-to-image generation using [Sana](https://huggingface.co/papers/2410.10629). This pipeline + supports the use of [Perturbed Attention Guidance + (PAG)](https://huggingface.co/docs/diffusers/main/en/using-diffusers/pag). + """ + + # fmt: off + bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}") + # fmt: on + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + text_encoder: Gemma2PreTrainedModel, + vae: AutoencoderDC, + transformer: SanaTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + pag_applied_layers: str | list[str] = "transformer_blocks.0", + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.encoder_block_out_channels) - 1) + if hasattr(self, "vae") and self.vae is not None + else 8 + ) + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.set_pag_applied_layers( + pag_applied_layers, + pag_attn_processors=(PAGCFGSanaLinearAttnProcessor2_0(), PAGIdentitySanaLinearAttnProcessor2_0()), + ) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the "" string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + # See Section 3.1. of the paper. + max_length = max_sequence_length + select_index = [0] + list(range(-max_length + 1, 0)) + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + + # prepare complex human instruction + if not complex_human_instruction: + max_length_all = max_length + else: + chi_prompt = "\n".join(complex_human_instruction) + prompt = [chi_prompt + p for p in prompt] + num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) + max_length_all = num_chi_prompt_tokens + max_length - 2 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length_all, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0][:, select_index] + prompt_attention_mask = prompt_attention_mask[:, select_index] + + if self.transformer is not None: + dtype = self.transformer.dtype + elif self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: int | None = 1, + height: int = 1024, + width: int = 1024, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + clean_caption: bool = False, + use_resolution_binning: bool = True, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 300, + complex_human_instruction: list[str] = [ + "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:", + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.", + "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.", + "Here are examples of how to transform or refine prompts:", + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.", + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.", + "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:", + "User Prompt: ", + ], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ) -> ImagePipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 20): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 300): Maximum sequence length to use with the `prompt`. + complex_human_instruction (`list[str]`, *optional*): + Instructions for complex human attention: + https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + if use_resolution_binning: + if self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_4096_BIN + elif self.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_2048_BIN + elif self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.transformer.config.sample_size == 16: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + callback_on_step_end_tensor_inputs, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + ) + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + prompt_attention_mask = self._prepare_perturbed_attention_guidance( + prompt_attention_mask, negative_prompt_attention_mask, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + torch.float32, + device, + generator, + latents, + ) + if self.do_perturbed_attention_guidance: + original_attn_proc = self.transformer.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute previous image: x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + else: + latents = latents.to(self.vae.dtype) + torch_accelerator_module = getattr(torch, get_device(), torch.cuda) + oom_error = ( + torch.OutOfMemoryError + if is_torch_version(">=", "2.5.0") + else torch_accelerator_module.OutOfMemoryError + ) + try: + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + except oom_error as e: + warnings.warn( + f"{e}. \n" + f"Try to use VAE tiling for large images. For example: \n" + f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)" + ) + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.transformer.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd.py new file mode 100644 index 0000000000000000000000000000000000000000..26ea717556c5f8cf5f59c4b2504cbead71279316 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd.py @@ -0,0 +1,1082 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import Any, Callable + +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...configuration_utils import FrozenDict +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForText2Image + + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, enable_pag=True + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, pag_scale=0.3).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionPAGPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, + PAGMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + pag_applied_layers: str | list[str] = "mid", + ): + super().__init__() + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + guidance_rescale: float = 0.0, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + # to deal with lora scaling and other possible forward hooks + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + None, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": ip_adapter_image_embeds} + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) + else None + ) + + # 6.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True + ) + + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_3.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_3.py new file mode 100644 index 0000000000000000000000000000000000000000..f0fbef29b699bdfe31cc676b2b0742876f5ba0f0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_3.py @@ -0,0 +1,998 @@ +# Copyright 2025 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin +from ...models.attention_processor import PAGCFGJointAttnProcessor2_0, PAGJointAttnProcessor2_0 +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import SD3Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForText2Image + + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "stabilityai/stable-diffusion-3-medium-diffusers", + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... pag_applied_layers=["blocks.13"], + ... ) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> image = pipe(prompt, guidance_scale=5.0, pag_scale=0.7).images[0] + >>> image.save("sd3_pag.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, PAGMixin): + r""" + [PAG pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/pag) for text-to-image generation + using Stable Diffusion 3. + + Args: + transformer ([`SD3Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, + with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size` + as its dimension. + text_encoder_2 ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + text_encoder_3 ([`T5EncoderModel`]): + Frozen text-encoder. Stable Diffusion 3 uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_3 (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] + + def __init__( + self, + transformer: SD3Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5TokenizerFast, + pag_applied_layers: str | list[str] = "blocks.1", # 1st transformer block + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + self.patch_size = ( + self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2 + ) + + self.set_pag_applied_layers( + pag_applied_layers, pag_attn_processors=(PAGCFGJointAttnProcessor2_0(), PAGJointAttnProcessor2_0()) + ) + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 256, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if self.text_encoder_3 is None: + return torch.zeros( + ( + batch_size * num_images_per_prompt, + max_sequence_length, + self.transformer.config.joint_attention_dim, + ), + device=device, + dtype=dtype, + ) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] + + dtype = self.text_encoder_3.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + clip_skip: int | None = None, + clip_model_index: int = 0, + ): + device = device or self._execution_device + + clip_tokenizers = [self.tokenizer, self.tokenizer_2] + clip_text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str], + prompt_3: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_3: str | list[str] | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + clip_skip: int | None = None, + max_sequence_length: int = 256, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + prompt=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=1, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( + negative_prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( + negative_prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) + + t5_negative_prompt_embed = self._get_t5_prompt_embeds( + prompt=negative_prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_clip_prompt_embeds = torch.nn.functional.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + + negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + ) + + if self.text_encoder is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if ( + height % (self.vae_scale_factor * self.patch_size) != 0 + or width % (self.vae_scale_factor * self.patch_size) != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}." + f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_3 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): + raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_3 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + prompt_3: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 28, + sigmas: list[float] | None = None, + guidance_scale: float = 7.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_3: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 256, + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used instead + negative_prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used instead + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale # + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + pooled_prompt_embeds = self._prepare_perturbed_attention_guidance( + pooled_prompt_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.transformer.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.transformer.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..84b727dc061350d22be13f025c2c2d773dcbe2f2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py @@ -0,0 +1,1062 @@ +# Copyright 2025 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import PIL.Image +import torch +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin +from ...models.attention_processor import PAGCFGJointAttnProcessor2_0, PAGJointAttnProcessor2_0 +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import SD3Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusion3PAGImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> pipe = StableDiffusion3PAGImg2ImgPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-3-medium-diffusers", + ... torch_dtype=torch.float16, + ... pag_applied_layers=["blocks.13"], + ... ) + >>> pipe.to("cuda") + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png" + >>> init_image = load_image(url).convert("RGB") + >>> image = pipe(prompt, image=init_image, guidance_scale=5.0, pag_scale=0.7).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, PAGMixin): + r""" + [PAG pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/pag) for image-to-image generation + using Stable Diffusion 3. + + Args: + transformer ([`SD3Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, + with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size` + as its dimension. + text_encoder_2 ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + text_encoder_3 ([`T5EncoderModel`]): + Frozen text-encoder. Stable Diffusion 3 uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_3 (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] + + def __init__( + self, + transformer: SD3Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5TokenizerFast, + pag_applied_layers: str | list[str] = "blocks.1", # 1st transformer block + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + self.patch_size = ( + self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2 + ) + + self.set_pag_applied_layers( + pag_applied_layers, pag_attn_processors=(PAGCFGJointAttnProcessor2_0(), PAGJointAttnProcessor2_0()) + ) + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 256, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if self.text_encoder_3 is None: + return torch.zeros( + ( + batch_size * num_images_per_prompt, + max_sequence_length, + self.transformer.config.joint_attention_dim, + ), + device=device, + dtype=dtype, + ) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] + + dtype = self.text_encoder_3.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + clip_skip: int | None = None, + clip_model_index: int = 0, + ): + device = device or self._execution_device + + clip_tokenizers = [self.tokenizer, self.tokenizer_2] + clip_text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str], + prompt_3: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_3: str | list[str] | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + clip_skip: int | None = None, + max_sequence_length: int = 256, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + prompt=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=1, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( + negative_prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( + negative_prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) + + t5_negative_prompt_embed = self._get_t5_prompt_embeds( + prompt=negative_prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_clip_prompt_embeds = torch.nn.functional.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + + negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + ) + + if self.text_encoder is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + height, + width, + strength, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if ( + height % (self.vae_scale_factor * self.patch_size) != 0 + or width % (self.vae_scale_factor * self.patch_size) != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}." + f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}." + ) + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_3 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): + raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_3 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.prepare_latents + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + if image.shape[1] == self.vae.config.latent_channels: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + init_latents = (init_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.scale_noise(init_latents, timestep, noise) + latents = init_latents.to(device=device, dtype=dtype) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + prompt_3: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + image: PipelineImageInput = None, + strength: float = 0.6, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 7.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_3: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 256, + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + will be used instead + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used instead + negative_prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used instead + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + strength, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + pooled_prompt_embeds = self._prepare_perturbed_attention_guidance( + pooled_prompt_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 3. Preprocess image + image = self.image_processor.preprocess(image, height=height, width=width) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.transformer.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + # 6. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.transformer.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py new file mode 100644 index 0000000000000000000000000000000000000000..62d1c912283f7258741c846d3a4df5110768bda2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py @@ -0,0 +1,878 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput +from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...models.unets.unet_motion_model import MotionAdapter +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..animatediff.pipeline_output import AnimateDiffPipelineOutput +from ..free_init_utils import FreeInitMixin +from ..free_noise_utils import AnimateDiffFreeNoiseMixin +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AnimateDiffPAGPipeline, MotionAdapter, DDIMScheduler + >>> from diffusers.utils import export_to_gif + + >>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE" + >>> motion_adapter_id = "guoyww/animatediff-motion-adapter-v1-5-2" + >>> motion_adapter = MotionAdapter.from_pretrained(motion_adapter_id) + >>> scheduler = DDIMScheduler.from_pretrained( + ... model_id, subfolder="scheduler", beta_schedule="linear", steps_offset=1, clip_sample=False + ... ) + >>> pipe = AnimateDiffPAGPipeline.from_pretrained( + ... model_id, + ... motion_adapter=motion_adapter, + ... scheduler=scheduler, + ... pag_applied_layers=["mid"], + ... torch_dtype=torch.float16, + ... ).to("cuda") + + >>> video = pipe( + ... prompt="car, futuristic cityscape with neon lights, street, no human", + ... negative_prompt="low quality, bad quality", + ... num_inference_steps=25, + ... guidance_scale=6.0, + ... pag_scale=3.0, + ... generator=torch.Generator().manual_seed(42), + ... ).frames[0] + + >>> export_to_gif(video, "animatediff_pag.gif") + ``` +""" + + +class AnimateDiffPAGPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FreeInitMixin, + AnimateDiffFreeNoiseMixin, + PAGMixin, +): + r""" + Pipeline for text-to-video generation using + [AnimateDiff](https://huggingface.co/docs/diffusers/en/api/pipelines/animatediff) and [Perturbed Attention + Guidance](https://huggingface.co/docs/diffusers/en/using-diffusers/pag). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. + motion_adapter ([`MotionAdapter`]): + A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel | UNetMotionModel, + motion_adapter: MotionAdapter, + scheduler: KarrasDiffusionSchedulers, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + pag_applied_layers: str | list[str] = "mid_block.*attn1", # ["mid"], ["down_blocks.1"] + ): + super().__init__() + if isinstance(unet, UNet2DConditionModel): + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + motion_adapter=motion_adapter, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents + def decode_latents(self, latents, decode_chunk_size: int = 16): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + video = [] + for i in range(0, latents.shape[0], decode_chunk_size): + batch_latents = latents[i : i + decode_chunk_size] + batch_latents = self.vae.decode(batch_latents).sample + video.append(batch_latents) + + video = torch.cat(video) + video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.pia.pipeline_pia.PIAPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://huggingface.co/papers/2310.15169) + if self.free_noise_enabled: + latents = self._prepare_latents_free_noise( + batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + num_frames: int | None = 16, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_videos_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + decode_chunk_size: int = 16, + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": ip_adapter_image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 + for free_init_iter in range(num_free_init_iters): + if self.free_init_enabled: + latents, timesteps = self._apply_free_init( + latents, free_init_iter, num_inference_steps, device, latents.dtype, generator + ) + + self._num_timesteps = len(timesteps) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 8. Denoising loop + with self.progress_bar(total=self._num_timesteps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat( + [latents] * (prompt_embeds.shape[0] // num_frames // latents.shape[0]) + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # 9. Post processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents, decode_chunk_size) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # 10. Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (video,) + + return AnimateDiffPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..822483eca995e43da9f3c8b68cb71dfc6ae44c94 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py @@ -0,0 +1,1112 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import PIL.Image +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...configuration_utils import FrozenDict +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForImage2Image + >>> from diffusers.utils import load_image + + >>> pipe = AutoPipelineForImage2Image.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-v1-5", + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... ) + >>> pipe = pipe.to("cuda") + >>> url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png" + + >>> init_image = load_image(url).convert("RGB") + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, image=init_image, pag_scale=0.3).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionPAGImg2ImgPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, + PAGMixin, +): + r""" + Pipeline for text-guided image-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + pag_applied_layers: str | list[str] = "mid", # ["mid"], ["down.block_1", "up.block_0.attentions_0"] + ): + super().__init__() + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + strength, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + strength: float = 0.8, + num_inference_steps: int | None = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float | None = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float | None = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + strength, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. set timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 7.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + self._num_timesteps = len(timesteps) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + if ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..0f6fbbd9ae1607402a457670d78688ad77544beb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py @@ -0,0 +1,1376 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import Any, Callable + +import PIL.Image +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...configuration_utils import FrozenDict +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .pag_utils import PAGMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForInpainting + + >>> pipe = AutoPipelineForInpainting.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, enable_pag=True + ... ) + >>> pipe = pipe.to("cuda") + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + >>> init_image = load_image(img_url).convert("RGB") + >>> mask_image = load_image(mask_url).convert("RGB") + >>> prompt = "A majestic tiger sitting on a bench" + >>> image = pipe( + ... prompt=prompt, + ... image=init_image, + ... mask_image=mask_image, + ... strength=0.8, + ... num_inference_steps=50, + ... guidance_scale=guidance_scale, + ... generator=generator, + ... pag_scale=pag_scale, + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionPAGInpaintPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, + PAGMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + pag_applied_layers: str | list[str] = "mid", + ): + super().__init__() + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.check_inputs + def check_inputs( + self, + prompt, + image, + mask_image, + height, + width, + strength, + callback_steps, + output_type, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 4: + image_latents = image + else: + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: torch.Tensor = None, + height: int | None = None, + width: int | None = None, + padding_mask_crop: int | None = None, + strength: float = 0.9999, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + guidance_rescale: float = 0.0, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + # to deal with lora scaling and other possible forward hooks + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + mask_image, + height, + width, + strength, + None, + None, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + padding_mask_crop, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. set timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask_condition = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + ) + if self.do_perturbed_attention_guidance: + if self.do_classifier_free_guidance: + mask, _ = mask.chunk(2) + masked_image_latents, _ = masked_image_latents.chunk(2) + mask = self._prepare_perturbed_attention_guidance(mask, mask, self.do_classifier_free_guidance) + masked_image_latents = self._prepare_perturbed_attention_guidance( + masked_image_latents, masked_image_latents, self.do_classifier_free_guidance + ) + + # 8. Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for stable-diffusion-v1-5/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + # 9 Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 9.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": ip_adapter_image_embeds} + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) + else None + ) + + # 9.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 10. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + ) + + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if num_channels_unet == 4: + init_latents_proper = image_latents + if self.do_perturbed_attention_guidance: + init_mask, *_ = mask.chunk(3) if self.do_classifier_free_guidance else mask.chunk(2) + else: + init_mask, *_ = mask.chunk(2) if self.do_classifier_free_guidance else mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + mask = callback_outputs.pop("mask", mask) + masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + condition_kwargs = {} + if isinstance(self.vae, AsymmetricAutoencoderKL): + init_image = init_image.to(device=device, dtype=masked_image_latents.dtype) + init_image_condition = init_image.clone() + init_image = self._encode_vae_image(init_image, generator=generator) + mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype) + condition_kwargs = {"image": init_image_condition, "mask": mask_condition} + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False, generator=generator, **condition_kwargs + )[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if padding_mask_crop is not None: + image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_xl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..2987c90626ef62f88b3db276547d25537a589a5e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_xl.py @@ -0,0 +1,1344 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from .pag_utils import PAGMixin + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForText2Image + + >>> pipe = AutoPipelineForText2Image.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, pag_scale=0.3).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLPAGPipeline( + DiffusionPipeline, + StableDiffusionMixin, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + PAGMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool | None = None, + pag_applied_layers: str | list[str] = "mid", # ["mid"],["down.block_1"],["up.block_0.attentions_0"] + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + denoising_end: float | None = None, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + guidance_rescale: float = 0.0, + original_size: tuple[int, int] | None = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] | None = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + None, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + add_text_embeds = self._prepare_perturbed_attention_guidance( + add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + add_time_ids = self._prepare_perturbed_attention_guidance( + add_time_ids, negative_add_time_ids, self.do_classifier_free_guidance + ) + + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9. Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance, perturbed-attention guidance, or both + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True + ) + + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..433b9edc69b7c895c2aa645fc91480cd654957f5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py @@ -0,0 +1,1539 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import empty_device_cache, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from .pag_utils import PAGMixin + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForImage2Image + >>> from diffusers.utils import load_image + + >>> pipe = AutoPipelineForImage2Image.from_pretrained( + ... "stabilityai/stable-diffusion-xl-refiner-1.0", + ... torch_dtype=torch.float16, + ... enable_pag=True, + ... ) + >>> pipe = pipe.to("cuda") + >>> url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png" + + >>> init_image = load_image(url).convert("RGB") + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, image=init_image, pag_scale=0.3).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLPAGImg2ImgPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + IPAdapterMixin, + PAGMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): + Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the + config of `stabilityai/stable-diffusion-xl-refiner-1-0`. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "add_neg_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool | None = None, + pag_applied_layers: str | list[str] = "mid", # ["mid"], ["down.block_1", "up.block_0.attentions_0"] + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + strength, + num_inference_steps, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if num_inference_steps is None: + raise ValueError("`num_inference_steps` cannot be None.") + elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(self.scheduler.timesteps) - num_inference_steps + timesteps = self.scheduler.timesteps[t_start:] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents + def prepare_latents( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + + # Offload text encoder if `enable_model_cpu_offload` was enabled + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.text_encoder_2.to("cpu") + empty_device_cache() + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def denoising_start(self): + return self._denoising_start + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + image: PipelineImageInput = None, + strength: float = 0.3, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + denoising_start: float | None = None, + denoising_end: float | None = None, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + guidance_rescale: float = 0.0, + original_size: tuple[int, int] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`torch.Tensor` or `PIL.Image.Image` or `np.ndarray` or `list[torch.Tensor]` or `list[PIL.Image.Image]` or `list[np.ndarray]`): + The image(s) to modify with the pipeline. + strength (`float`, *optional*, defaults to 0.3): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of + `denoising_start` being declared as an integer, the value of `strength` will be ignored. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_start (`float`, *optional*): + When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be + bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and + it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, + strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline + is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image + Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality). + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be + denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the + final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline + forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refine Image + Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality). + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + strength, + num_inference_steps, + None, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._denoising_start = denoising_start + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. Prepare timesteps + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None, + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + add_noise = True if self.denoising_start is None else False + + # 6. Prepare latent variables + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + add_noise, + ) + # 7. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + height, width = latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 8. Prepare added time ids & embeddings + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + add_text_embeds = self._prepare_perturbed_attention_guidance( + add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + add_time_ids = self._prepare_perturbed_attention_guidance( + add_time_ids, add_neg_time_ids, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 9. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 9.1 Apply denoising_end + if ( + self.denoising_end is not None + and self.denoising_start is not None + and denoising_value_valid(self.denoising_end) + and denoising_value_valid(self.denoising_start) + and self.denoising_start >= self.denoising_end + ): + raise ValueError( + f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " + + f" {self.denoising_end} when using type float." + ) + elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..9caf50e5e333ec4ce9295e5a597a094c0d749087 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py @@ -0,0 +1,1770 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +from .pag_utils import PAGMixin + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoPipelineForInpainting + >>> from diffusers.utils import load_image + + >>> pipe = AutoPipelineForInpainting.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", + ... torch_dtype=torch.float16, + ... variant="fp16", + ... enable_pag=True, + ... ) + >>> pipe.to("cuda") + + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + >>> init_image = load_image(img_url).convert("RGB") + >>> mask_image = load_image(mask_url).convert("RGB") + + >>> prompt = "A majestic tiger sitting on a bench" + >>> image = pipe( + ... prompt=prompt, + ... image=init_image, + ... mask_image=mask_image, + ... num_inference_steps=50, + ... strength=0.80, + ... pag_scale=0.3, + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLPAGInpaintPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, + IPAdapterMixin, + PAGMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): + Whether the `unet` requires a aesthetic_score condition to be passed during inference. Also see the config + of `stabilityai/stable-diffusion-xl-refiner-1-0`. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "add_neg_time_ids", + "mask", + "masked_image_latents", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool | None = None, + pag_applied_layers: str | list[str] = "mid", # ["mid"], ["down.block_1", "up.block_0.attentions_0"] + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.set_pag_applied_layers(pag_applied_layers) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + image, + mask_image, + height, + width, + strength, + callback_steps, + output_type, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + add_noise=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if image.shape[1] == 4: + image_latents = image.to(device=device, dtype=dtype) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + elif return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None and add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + elif add_noise: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + else: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = image_latents.to(device) + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + dtype = image.dtype + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + image_latents = image_latents.to(dtype) + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(self.scheduler.timesteps) - num_inference_steps + timesteps = self.scheduler.timesteps[t_start:] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def denoising_start(self): + return self._denoising_start + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: torch.Tensor = None, + height: int | None = None, + width: int | None = None, + padding_mask_crop: int | None = None, + strength: float = 0.9999, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + denoising_start: float | None = None, + denoising_end: float | None = None, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + guidance_rescale: float = 0.0, + original_size: tuple[int, int] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + pag_scale: float = 3.0, + pag_adaptive_scale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 0.9999): + Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be + between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the + `strength`. The number of denoising steps depends on the amount of noise initially added. When + `strength` is 1, added noise will be maximum and the denoising process will run for the full number of + iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked + portion of the reference `image`. Note that in the case of `denoising_start` being declared as an + integer, the value of `strength` will be ignored. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_start (`float`, *optional*): + When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be + bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and + it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, + strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline + is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be + denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the + final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline + forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + pag_scale (`float`, *optional*, defaults to 3.0): + The scale factor for the perturbed attention guidance. If it is set to 0.0, the perturbed attention + guidance will not be used. + pag_adaptive_scale (`float`, *optional*, defaults to 0.0): + The adaptive scale factor for the perturbed attention guidance. If it is set to 0.0, `pag_scale` is + used. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. `tuple. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs + self.check_inputs( + prompt, + prompt_2, + image, + mask_image, + height, + width, + strength, + None, + output_type, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + padding_mask_crop, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._denoising_start = denoising_start + self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scale = pag_adaptive_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. set timesteps + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None, + ) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + mask = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is not None: + masked_image = masked_image_latents + elif init_image.shape[1] == 4: + # if images are in latent space, we can't mask it + masked_image = None + else: + masked_image = init_image * (mask < 0.5) + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + add_noise = True if self.denoising_start is None else False + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + add_noise=add_noise, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + ) + if self.do_perturbed_attention_guidance: + if self.do_classifier_free_guidance: + mask, _ = mask.chunk(2) + masked_image_latents, _ = masked_image_latents.chunk(2) + mask = self._prepare_perturbed_attention_guidance(mask, mask, self.do_classifier_free_guidance) + masked_image_latents = self._prepare_perturbed_attention_guidance( + masked_image_latents, masked_image_latents, self.do_classifier_free_guidance + ) + + # 8. Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for stable-diffusion-v1-5/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + # 8.1 Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + height, width = latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 10. Prepare added time ids & embeddings + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + if self.do_perturbed_attention_guidance: + prompt_embeds = self._prepare_perturbed_attention_guidance( + prompt_embeds, negative_prompt_embeds, self.do_classifier_free_guidance + ) + add_text_embeds = self._prepare_perturbed_attention_guidance( + add_text_embeds, negative_pooled_prompt_embeds, self.do_classifier_free_guidance + ) + add_time_ids = self._prepare_perturbed_attention_guidance( + add_time_ids, add_neg_time_ids, self.do_classifier_free_guidance + ) + + elif self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + for i, image_embeds in enumerate(ip_adapter_image_embeds): + negative_image_embeds = None + if self.do_classifier_free_guidance: + negative_image_embeds, image_embeds = image_embeds.chunk(2) + + if self.do_perturbed_attention_guidance: + image_embeds = self._prepare_perturbed_attention_guidance( + image_embeds, negative_image_embeds, self.do_classifier_free_guidance + ) + elif self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + image_embeds = image_embeds.to(device) + ip_adapter_image_embeds[i] = image_embeds + + # 11. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + if ( + self.denoising_end is not None + and self.denoising_start is not None + and denoising_value_valid(self.denoising_end) + and denoising_value_valid(self.denoising_start) + and self.denoising_start >= self.denoising_end + ): + raise ValueError( + f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " + + f" {self.denoising_end} when using type float." + ) + elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 11.1 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + if self.do_perturbed_attention_guidance: + original_attn_proc = self.unet.attn_processors + self._set_pag_attn_processor( + pag_applied_layers=self.pag_applied_layers, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = ip_adapter_image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_perturbed_attention_guidance: + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True + ) + elif self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if num_channels_unet == 4: + init_latents_proper = image_latents + if self.do_perturbed_attention_guidance: + init_mask, *_ = mask.chunk(3) if self.do_classifier_free_guidance else mask.chunk(2) + else: + init_mask, *_ = mask.chunk(2) if self.do_classifier_free_guidance else mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) + mask = callback_outputs.pop("mask", mask) + masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + return StableDiffusionXLPipelineOutput(images=latents) + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + if padding_mask_crop is not None: + image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] + + # Offload all models + self.maybe_free_model_hooks() + + if self.do_perturbed_attention_guidance: + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/paint_by_example/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/paint_by_example/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d2906b540c6eeb6f463340c8f856611e6c5dbd2f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/paint_by_example/__init__.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import numpy as np +import PIL +from PIL import Image + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["image_encoder"] = ["PaintByExampleImageEncoder"] + _import_structure["pipeline_paint_by_example"] = ["PaintByExamplePipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .image_encoder import PaintByExampleImageEncoder + from .pipeline_paint_by_example import PaintByExamplePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/paint_by_example/image_encoder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/paint_by_example/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..74c575ed8653d110ca300e9198a5b13070659adf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/paint_by_example/image_encoder.py @@ -0,0 +1,67 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch import nn +from transformers import CLIPPreTrainedModel, CLIPVisionModel + +from ...models.attention import BasicTransformerBlock +from ...utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class PaintByExampleImageEncoder(CLIPPreTrainedModel): + def __init__(self, config, proj_size=None): + super().__init__(config) + self.proj_size = proj_size or getattr(config, "projection_dim", 768) + + self.model = CLIPVisionModel(config) + self.mapper = PaintByExampleMapper(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size) + self.proj_out = nn.Linear(config.hidden_size, self.proj_size) + + # uncondition for scaling + self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size))) + + def forward(self, pixel_values, return_uncond_vector=False): + clip_output = self.model(pixel_values=pixel_values) + latent_states = clip_output.pooler_output + latent_states = self.mapper(latent_states[:, None]) + latent_states = self.final_layer_norm(latent_states) + latent_states = self.proj_out(latent_states) + if return_uncond_vector: + return latent_states, self.uncond_vector + + return latent_states + + +class PaintByExampleMapper(nn.Module): + def __init__(self, config): + super().__init__() + num_layers = (config.num_hidden_layers + 1) // 5 + hid_size = config.hidden_size + num_heads = 1 + self.blocks = nn.ModuleList( + [ + BasicTransformerBlock(hid_size, num_heads, hid_size, activation_fn="gelu", attention_bias=True) + for _ in range(num_layers) + ] + ) + + def forward(self, hidden_states): + for block in self.blocks: + hidden_states = block(hidden_states) + + return hidden_states diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py new file mode 100644 index 0000000000000000000000000000000000000000..aa7dbaa720e570edbed4acf497cb01c9a9339935 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -0,0 +1,633 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor + +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import deprecate, is_torch_xla_available, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .image_encoder import PaintByExampleImageEncoder + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def prepare_mask_and_masked_image(image, mask): + """ + Prepares a pair (image, mask) to be consumed by the Paint by Example pipeline. This means that those inputs will be + converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the + ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (np.array | PIL.Image | torch.Tensor): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Batched mask + if mask.shape[0] == image.shape[0]: + mask = mask.unsqueeze(1) + else: + mask = mask.unsqueeze(0) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + assert mask.shape[1] == 1, "Mask image must have a single channel" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # paint-by-example inverses the mask + mask = 1 - mask + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + if isinstance(image, PIL.Image.Image): + image = [image] + + image = np.concatenate([np.array(i.convert("RGB"))[None, :] for i in image], axis=0) + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, PIL.Image.Image): + mask = [mask] + + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + + # paint-by-example inverses the mask + mask = 1 - mask + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = image * mask + + return mask, masked_image + + +class PaintByExamplePipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin): + _last_supported_version = "0.33.1" + r""" + > [!WARNING] > 🧪 This is an experimental feature! + + Pipeline for image-guided image inpainting using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + image_encoder ([`PaintByExampleImageEncoder`]): + Encodes the example input image. The `unet` is conditioned on the example image instead of a text prompt. + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + + """ + + # TODO: feature_extractor is required to encode initial images (if they are in PIL format), + # we should give a descriptive message if the pipeline doesn't have one. + + model_cpu_offload_seq = "unet->vae" + _exclude_from_cpu_offload = ["image_encoder"] + _optional_components = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + image_encoder: PaintByExampleImageEncoder, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler | PNDMScheduler | LMSDiscreteScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = False, + ): + super().__init__() + + self.register_modules( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs + def check_inputs(self, image, height, width, callback_steps): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `list[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(images=image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings, negative_prompt_embeds = self.image_encoder(image, return_uncond_vector=True) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + negative_prompt_embeds = negative_prompt_embeds.repeat(1, image_embeddings.shape[0], 1) + negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, 1, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) + + return image_embeddings + + @torch.no_grad() + def __call__( + self, + example_image: torch.Tensor | PIL.Image.Image, + image: torch.Tensor | PIL.Image.Image, + mask_image: torch.Tensor | PIL.Image.Image, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + ): + r""" + The call function to the pipeline for generation. + + Args: + example_image (`torch.Tensor` or `PIL.Image.Image` or `list[PIL.Image.Image]`): + An example image to guide image generation. + image (`torch.Tensor` or `PIL.Image.Image` or `list[PIL.Image.Image]`): + `Image` or tensor representing an image batch to be inpainted (parts of the image are masked out with + `mask_image` and repainted according to `prompt`). + mask_image (`torch.Tensor` or `PIL.Image.Image` or `list[PIL.Image.Image]`): + `Image` or tensor representing an image batch to mask `image`. White pixels in the mask are repainted, + while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a single channel + (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the + expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Example: + + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + >>> from diffusers import PaintByExamplePipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = ( + ... "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/image/example_1.png" + ... ) + >>> mask_url = ( + ... "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/mask/example_1.png" + ... ) + >>> example_url = "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/reference/example_1.jpg" + + >>> init_image = download_image(img_url).resize((512, 512)) + >>> mask_image = download_image(mask_url).resize((512, 512)) + >>> example_image = download_image(example_url).resize((512, 512)) + + >>> pipe = PaintByExamplePipeline.from_pretrained( + ... "Fantasy-Studio/Paint-by-Example", + ... torch_dtype=torch.float16, + ... ) + >>> pipe = pipe.to("cuda") + + >>> image = pipe(image=init_image, mask_image=mask_image, example_image=example_image).images[0] + >>> image + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 1. Define call parameters + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 2. Preprocess mask and image + mask, masked_image = prepare_mask_and_masked_image(image, mask_image) + height, width = masked_image.shape[-2:] + + # 3. Check inputs + self.check_inputs(example_image, height, width, callback_steps) + + # 4. Encode input image + image_embeddings = self._encode_image( + example_image, device, num_images_per_prompt, do_classifier_free_guidance + ) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + image_embeddings.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 8. Check that sizes of mask, masked image and latents match + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 10. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latent_model_input, masked_image_latents, mask], dim=1) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + self.maybe_free_model_hooks() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pia/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pia/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..16e8004966e58387f7dba2f6ff3175575fe0abee --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pia/__init__.py @@ -0,0 +1,46 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_pia"] = ["PIAPipeline", "PIAPipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + + else: + from .pipeline_pia import PIAPipeline, PIAPipelineOutput + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pia/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pia/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49b7314e360cf17fa06b4f58a6cb0e3bfe26ed20 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pia/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pia/__pycache__/pipeline_pia.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pia/__pycache__/pipeline_pia.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8151905b6ca4f6d62f5abeb02ced09610256fbe9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pia/__pycache__/pipeline_pia.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pia/pipeline_pia.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pia/pipeline_pia.py new file mode 100644 index 0000000000000000000000000000000000000000..d108deb9c5dacc013e2edd80f5ba1fb1ff7fa0eb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pia/pipeline_pia.py @@ -0,0 +1,956 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable + +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...models.unets.unet_motion_model import MotionAdapter +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import ( + USE_PEFT_BACKEND, + BaseOutput, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..free_init_utils import FreeInitMixin +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import EulerDiscreteScheduler, MotionAdapter, PIAPipeline + >>> from diffusers.utils import export_to_gif, load_image + + >>> adapter = MotionAdapter.from_pretrained("openmmlab/PIA-condition-adapter") + >>> pipe = PIAPipeline.from_pretrained( + ... "SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16 + ... ) + + >>> pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) + >>> image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true" + ... ) + >>> image = image.resize((512, 512)) + >>> prompt = "cat in a hat" + >>> negative_prompt = "wrong white balance, dark, sketches, worst quality, low quality, deformed, distorted" + >>> generator = torch.Generator("cpu").manual_seed(0) + >>> output = pipe(image=image, prompt=prompt, negative_prompt=negative_prompt, generator=generator) + >>> frames = output.frames[0] + >>> export_to_gif(frames, "pia-animation.gif") + ``` +""" + +RANGE_list = [ + [1.0, 0.9, 0.85, 0.85, 0.85, 0.8], # 0 Small Motion + [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75], # Moderate Motion + [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5], # Large Motion + [1.0, 0.9, 0.85, 0.85, 0.85, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.85, 0.85, 0.9, 1.0], # Loop + [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75, 0.75, 0.75, 0.75, 0.75, 0.78, 0.79, 0.8, 0.8, 1.0], # Loop + [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5, 0.6, 0.7, 0.7, 0.7, 0.7, 0.8, 1.0], # Loop + [0.5, 0.4, 0.4, 0.4, 0.35, 0.3], # Style Transfer Candidate Small Motion + [0.5, 0.4, 0.4, 0.4, 0.35, 0.35, 0.3, 0.25, 0.2], # Style Transfer Moderate Motion + [0.5, 0.2], # Style Transfer Large Motion +] + + +def prepare_mask_coef_by_statistics(num_frames: int, cond_frame: int, motion_scale: int): + assert num_frames > 0, "video_length should be greater than 0" + + assert num_frames > cond_frame, "video_length should be greater than cond_frame" + + range_list = RANGE_list + + assert motion_scale < len(range_list), f"motion_scale type{motion_scale} not implemented" + + coef = range_list[motion_scale] + coef = coef + ([coef[-1]] * (num_frames - len(coef))) + + order = [abs(i - cond_frame) for i in range(num_frames)] + coef = [coef[order[i]] for i in range(num_frames)] + + return coef + + +@dataclass +class PIAPipelineOutput(BaseOutput): + r""" + Output class for PIAPipeline. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + Nested list of length `batch_size` with denoised PIL image sequences of length `num_frames`, NumPy array of + shape `(batch_size, num_frames, channels, height, width, Torch tensor of shape `(batch_size, num_frames, + channels, height, width)`. + """ + + frames: torch.Tensor | np.ndarray | list[list[PIL.Image.Image]] + + +class PIAPipeline( + DeprecatedPipelineMixin, + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, + FreeInitMixin, +): + _last_supported_version = "0.33.1" + r""" + Pipeline for text-to-video generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents. + motion_adapter ([`MotionAdapter`]): + A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["feature_extractor", "image_encoder", "motion_adapter"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel | UNetMotionModel, + scheduler: DDIMScheduler + | PNDMScheduler + | LMSDiscreteScheduler + | EulerDiscreteScheduler + | EulerAncestralDiscreteScheduler + | DPMSolverMultistepScheduler, + motion_adapter: MotionAdapter | None = None, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + if isinstance(unet, UNet2DConditionModel): + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + motion_adapter=motion_adapter, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + image = self.vae.decode(latents).sample + video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_masked_condition( + self, + image, + batch_size, + num_channels_latents, + num_frames, + height, + width, + dtype, + device, + generator, + motion_scale=0, + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + _, _, _, scaled_height, scaled_width = shape + + image = self.video_processor.preprocess(image) + image = image.to(device, dtype) + + if isinstance(generator, list): + image_latent = [ + self.vae.encode(image[k : k + 1]).latent_dist.sample(generator[k]) for k in range(batch_size) + ] + image_latent = torch.cat(image_latent, dim=0) + else: + image_latent = self.vae.encode(image).latent_dist.sample(generator) + + image_latent = image_latent.to(device=device, dtype=dtype) + image_latent = torch.nn.functional.interpolate(image_latent, size=[scaled_height, scaled_width]) + image_latent_padding = image_latent.clone() * self.vae.config.scaling_factor + + mask = torch.zeros((batch_size, 1, num_frames, scaled_height, scaled_width)).to(device=device, dtype=dtype) + mask_coef = prepare_mask_coef_by_statistics(num_frames, 0, motion_scale) + masked_image = torch.zeros(batch_size, 4, num_frames, scaled_height, scaled_width).to( + device=device, dtype=self.unet.dtype + ) + for f in range(num_frames): + mask[:, :, f, :, :] = mask_coef[f] + masked_image[:, :, f, :, :] = image_latent_padding.clone() + + mask = torch.cat([mask] * 2) if self.do_classifier_free_guidance else mask + masked_image = torch.cat([masked_image] * 2) if self.do_classifier_free_guidance else masked_image + + return mask, masked_image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: str | list[str] = None, + strength: float = 1.0, + num_frames: int | None = 16, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_videos_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + motion_scale: int = 0, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to be used for video generation. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + motion_scale: (`int`, *optional*, defaults to 0): + Parameter that controls the amount and type of motion that is added to the image. Increasing the value + increases the amount of motion, while specific ranges of values control the type of motion that is + added. Must be between 0 and 8. Set between 0-2 to only increase the amount of motion. Set between 3-5 + to create looping motion. Set between 6-8 to perform motion with image style transfer. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.pia.pipeline_pia.PIAPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.pia.pipeline_pia.PIAPipelineOutput`] is returned, otherwise a + `tuple` is returned where the first element is a list with the generated frames. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + 4, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents=latents, + ) + mask, masked_image = self.prepare_masked_condition( + image, + batch_size * num_videos_per_prompt, + 4, + num_frames=num_frames, + height=height, + width=width, + dtype=self.unet.dtype, + device=device, + generator=generator, + motion_scale=motion_scale, + ) + if strength < 1.0: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + latents = self.scheduler.add_noise(masked_image[0], noise, latent_timestep) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 8. Denoising loop + num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 + for free_init_iter in range(num_free_init_iters): + if self.free_init_enabled: + latents, timesteps = self._apply_free_init( + latents, free_init_iter, num_inference_steps, device, latents.dtype, generator + ) + + self._num_timesteps = len(timesteps) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + with self.progress_bar(total=self._num_timesteps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latent_model_input, mask, masked_image], dim=1) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # 9. Post processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # 10. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return PIAPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pixart_alpha/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pixart_alpha/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7cb870fc8589aeff366429c182736b2ddd6ce215 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pixart_alpha/__init__.py @@ -0,0 +1,55 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_pixart_alpha"] = ["PixArtAlphaPipeline"] + _import_structure["pipeline_pixart_sigma"] = ["PixArtSigmaPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_pixart_alpha import ( + ASPECT_RATIO_256_BIN, + ASPECT_RATIO_512_BIN, + ASPECT_RATIO_1024_BIN, + PixArtAlphaPipeline, + ) + from .pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN, PixArtSigmaPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pixart_alpha/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pixart_alpha/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f1a163b4480769474a4bb58055070320e258608 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pixart_alpha/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pixart_alpha/__pycache__/pipeline_pixart_alpha.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pixart_alpha/__pycache__/pipeline_pixart_alpha.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a5390143a7977c2e95d74e2150a42c86b488097 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pixart_alpha/__pycache__/pipeline_pixart_alpha.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pixart_alpha/__pycache__/pipeline_pixart_sigma.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pixart_alpha/__pycache__/pipeline_pixart_sigma.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f1257a6611ca936500162ee641181d04d539789 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pixart_alpha/__pycache__/pipeline_pixart_sigma.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py new file mode 100644 index 0000000000000000000000000000000000000000..604e51d885834e507b852dfcca588260b2907178 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -0,0 +1,980 @@ +# Copyright 2025 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +from typing import Callable + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...image_processor import PixArtImageProcessor +from ...models import AutoencoderKL, PixArtTransformer2DModel +from ...schedulers import DPMSolverMultistepScheduler +from ...utils import ( + BACKENDS_MAPPING, + deprecate, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import PixArtAlphaPipeline + + >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too. + >>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16) + >>> # Enable memory optimizations. + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A small cactus with a happy face in the Sahara desert." + >>> image = pipe(prompt).images[0] + ``` +""" + +ASPECT_RATIO_1024_BIN = { + "0.25": [512.0, 2048.0], + "0.28": [512.0, 1856.0], + "0.32": [576.0, 1792.0], + "0.33": [576.0, 1728.0], + "0.35": [576.0, 1664.0], + "0.4": [640.0, 1600.0], + "0.42": [640.0, 1536.0], + "0.48": [704.0, 1472.0], + "0.5": [704.0, 1408.0], + "0.52": [704.0, 1344.0], + "0.57": [768.0, 1344.0], + "0.6": [768.0, 1280.0], + "0.68": [832.0, 1216.0], + "0.72": [832.0, 1152.0], + "0.78": [896.0, 1152.0], + "0.82": [896.0, 1088.0], + "0.88": [960.0, 1088.0], + "0.94": [960.0, 1024.0], + "1.0": [1024.0, 1024.0], + "1.07": [1024.0, 960.0], + "1.13": [1088.0, 960.0], + "1.21": [1088.0, 896.0], + "1.29": [1152.0, 896.0], + "1.38": [1152.0, 832.0], + "1.46": [1216.0, 832.0], + "1.67": [1280.0, 768.0], + "1.75": [1344.0, 768.0], + "2.0": [1408.0, 704.0], + "2.09": [1472.0, 704.0], + "2.4": [1536.0, 640.0], + "2.5": [1600.0, 640.0], + "3.0": [1728.0, 576.0], + "4.0": [2048.0, 512.0], +} + +ASPECT_RATIO_512_BIN = { + "0.25": [256.0, 1024.0], + "0.28": [256.0, 928.0], + "0.32": [288.0, 896.0], + "0.33": [288.0, 864.0], + "0.35": [288.0, 832.0], + "0.4": [320.0, 800.0], + "0.42": [320.0, 768.0], + "0.48": [352.0, 736.0], + "0.5": [352.0, 704.0], + "0.52": [352.0, 672.0], + "0.57": [384.0, 672.0], + "0.6": [384.0, 640.0], + "0.68": [416.0, 608.0], + "0.72": [416.0, 576.0], + "0.78": [448.0, 576.0], + "0.82": [448.0, 544.0], + "0.88": [480.0, 544.0], + "0.94": [480.0, 512.0], + "1.0": [512.0, 512.0], + "1.07": [512.0, 480.0], + "1.13": [544.0, 480.0], + "1.21": [544.0, 448.0], + "1.29": [576.0, 448.0], + "1.38": [576.0, 416.0], + "1.46": [608.0, 416.0], + "1.67": [640.0, 384.0], + "1.75": [672.0, 384.0], + "2.0": [704.0, 352.0], + "2.09": [736.0, 352.0], + "2.4": [768.0, 320.0], + "2.5": [800.0, 320.0], + "3.0": [864.0, 288.0], + "4.0": [1024.0, 256.0], +} + +ASPECT_RATIO_256_BIN = { + "0.25": [128.0, 512.0], + "0.28": [128.0, 464.0], + "0.32": [144.0, 448.0], + "0.33": [144.0, 432.0], + "0.35": [144.0, 416.0], + "0.4": [160.0, 400.0], + "0.42": [160.0, 384.0], + "0.48": [176.0, 368.0], + "0.5": [176.0, 352.0], + "0.52": [176.0, 336.0], + "0.57": [192.0, 336.0], + "0.6": [192.0, 320.0], + "0.68": [208.0, 304.0], + "0.72": [208.0, 288.0], + "0.78": [224.0, 288.0], + "0.82": [224.0, 272.0], + "0.88": [240.0, 272.0], + "0.94": [240.0, 256.0], + "1.0": [256.0, 256.0], + "1.07": [256.0, 240.0], + "1.13": [272.0, 240.0], + "1.21": [272.0, 224.0], + "1.29": [288.0, 224.0], + "1.38": [288.0, 208.0], + "1.46": [304.0, 208.0], + "1.67": [320.0, 192.0], + "1.75": [336.0, 192.0], + "2.0": [352.0, 176.0], + "2.09": [368.0, 176.0], + "2.4": [384.0, 160.0], + "2.5": [400.0, 160.0], + "3.0": [432.0, 144.0], + "4.0": [512.0, 128.0], +} + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class PixArtAlphaPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using PixArt-Alpha. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. PixArt-Alpha uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`PixArtTransformer2DModel`]): + A text conditioned `PixArtTransformer2DModel` to denoise the encoded image latents. Initially published as + [`Transformer2DModel`](https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS/blob/main/transformer/config.json#L2) + in the config, but the mismatch can be ignored. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: PixArtTransformer2DModel, + scheduler: DPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + clean_caption: bool = False, + max_sequence_length: int = 120, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" + string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt. + """ + + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + + if device is None: + device = self._execution_device + + # See Section 3.1. of the paper. + max_length = max_sequence_length + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because T5 can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt) + prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt) + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: int | None = 1, + height: int | None = None, + width: int | None = None, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + clean_caption: bool = True, + use_resolution_binning: bool = True, + max_sequence_length: int = 120, + **kwargs, + ) -> ImagePipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + max_sequence_length (`int` defaults to 120): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + # 1. Check inputs. Raise error if not correct + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + if use_resolution_binning: + if self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + elif self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_256_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Prepare micro-conditions. + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + if self.transformer.config.sample_size == 128: + resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1) + aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1) + resolution = resolution.to(dtype=prompt_embeds.dtype, device=device) + aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device) + + if do_classifier_free_guidance: + resolution = torch.cat([resolution, resolution], dim=0) + aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0) + + added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + is_npu = latent_model_input.device.type == "npu" + if isinstance(current_timestep, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=current_timestep, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + + # compute previous image: x_t -> x_t-1 + if num_inference_steps == 1: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[1] + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py new file mode 100644 index 0000000000000000000000000000000000000000..286695aa8eb98d49ef40f6dfc8349496fac7df67 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py @@ -0,0 +1,910 @@ +# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +from typing import Callable + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...image_processor import PixArtImageProcessor +from ...models import AutoencoderKL, PixArtTransformer2DModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + BACKENDS_MAPPING, + deprecate, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .pipeline_pixart_alpha import ( + ASPECT_RATIO_256_BIN, + ASPECT_RATIO_512_BIN, + ASPECT_RATIO_1024_BIN, +) + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +ASPECT_RATIO_2048_BIN = { + "0.25": [1024.0, 4096.0], + "0.26": [1024.0, 3968.0], + "0.27": [1024.0, 3840.0], + "0.28": [1024.0, 3712.0], + "0.32": [1152.0, 3584.0], + "0.33": [1152.0, 3456.0], + "0.35": [1152.0, 3328.0], + "0.4": [1280.0, 3200.0], + "0.42": [1280.0, 3072.0], + "0.48": [1408.0, 2944.0], + "0.5": [1408.0, 2816.0], + "0.52": [1408.0, 2688.0], + "0.57": [1536.0, 2688.0], + "0.6": [1536.0, 2560.0], + "0.68": [1664.0, 2432.0], + "0.72": [1664.0, 2304.0], + "0.78": [1792.0, 2304.0], + "0.82": [1792.0, 2176.0], + "0.88": [1920.0, 2176.0], + "0.94": [1920.0, 2048.0], + "1.0": [2048.0, 2048.0], + "1.07": [2048.0, 1920.0], + "1.13": [2176.0, 1920.0], + "1.21": [2176.0, 1792.0], + "1.29": [2304.0, 1792.0], + "1.38": [2304.0, 1664.0], + "1.46": [2432.0, 1664.0], + "1.67": [2560.0, 1536.0], + "1.75": [2688.0, 1536.0], + "2.0": [2816.0, 1408.0], + "2.09": [2944.0, 1408.0], + "2.4": [3072.0, 1280.0], + "2.5": [3200.0, 1280.0], + "2.89": [3328.0, 1152.0], + "3.0": [3456.0, 1152.0], + "3.11": [3584.0, 1152.0], + "3.62": [3712.0, 1024.0], + "3.75": [3840.0, 1024.0], + "3.88": [3968.0, 1024.0], + "4.0": [4096.0, 1024.0], +} + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import PixArtSigmaPipeline + + >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-Sigma-XL-2-512-MS" too. + >>> pipe = PixArtSigmaPipeline.from_pretrained( + ... "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16 + ... ) + >>> # Enable memory optimizations. + >>> # pipe.enable_model_cpu_offload() + + >>> prompt = "A small cactus with a happy face in the Sahara desert." + >>> image = pipe(prompt).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class PixArtSigmaPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using PixArt-Sigma. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. PixArt-Alpha uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`PixArtTransformer2DModel`]): + A text conditioned `PixArtTransformer2DModel` to denoise the encoded image latents. Initially published as + [`Transformer2DModel`](https://huggingface.co/PixArt-alpha/PixArt-Sigma-XL-2-1024-MS/blob/main/transformer/config.json#L2) + in the config, but the mismatch can be ignored. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: PixArtTransformer2DModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->300 + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" + string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + """ + + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + + if device is None: + device = self._execution_device + + # See Section 3.1. of the paper. + max_length = max_sequence_length + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because T5 can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt) + prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt) + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: int | None = 1, + height: int | None = None, + width: int | None = None, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + clean_caption: bool = True, + use_resolution_binning: bool = True, + max_sequence_length: int = 300, + **kwargs, + ) -> ImagePipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + max_sequence_length (`int` defaults to 300): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + # 1. Check inputs. Raise error if not correct + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + if use_resolution_binning: + if self.transformer.config.sample_size == 256: + aspect_ratio_bin = ASPECT_RATIO_2048_BIN + elif self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + elif self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_256_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Prepare micro-conditions. + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + is_npu = latent_model_input.device.type == "npu" + if isinstance(current_timestep, float): + dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + else: + dtype = torch.int32 if (is_mps or is_npu) else torch.int64 + current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=current_timestep, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents.to(self.vae.dtype) / self.vae.config.scaling_factor, return_dict=False)[0] + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/prx/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/prx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ad2948e92e04823dc6f2b0d14cf4178860ef383f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/prx/__init__.py @@ -0,0 +1,74 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["PRXPipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_prx"] = ["PRXPipeline"] + +# Wrap T5GemmaEncoder to pass config.encoder (T5GemmaModuleConfig) instead of the +# composite T5GemmaConfig, which lacks flat attributes expected by T5GemmaEncoder.__init__. +try: + if is_transformers_available(): + import transformers + from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder as _T5GemmaEncoder + + class T5GemmaEncoder(_T5GemmaEncoder): + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): + if "config" not in kwargs: + from transformers.models.t5gemma.configuration_t5gemma import T5GemmaConfig + + config = T5GemmaConfig.from_pretrained(pretrained_model_name_or_path) + if hasattr(config, "encoder"): + kwargs["config"] = config.encoder + return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + + _additional_imports["T5GemmaEncoder"] = T5GemmaEncoder + if not hasattr(transformers, "T5GemmaEncoder"): + transformers.T5GemmaEncoder = T5GemmaEncoder +except ImportError: + pass + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_output import PRXPipelineOutput + from .pipeline_prx import PRXPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/prx/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/prx/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..81f92c294735136779753592cf0b81de111abfff --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/prx/pipeline_output.py @@ -0,0 +1,34 @@ +# Copyright 2025 The Photoroom and the HuggingFace Teams. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class PRXPipelineOutput(BaseOutput): + """ + Output class for PRX pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image] | np.ndarray diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/prx/pipeline_prx.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/prx/pipeline_prx.py new file mode 100644 index 0000000000000000000000000000000000000000..e14815b91c41459973f12e5ba98256e723f398d9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/prx/pipeline_prx.py @@ -0,0 +1,801 @@ +# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +from typing import Callable + +import torch +from transformers import ( + AutoTokenizer, + GemmaTokenizerFast, + T5TokenizerFast, +) +from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder + +from diffusers.image_processor import PixArtImageProcessor +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderDC, AutoencoderKL +from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.prx.pipeline_output import PRXPipelineOutput +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import is_ftfy_available, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor + + +if is_ftfy_available(): + import ftfy + +DEFAULT_RESOLUTION = 512 + +ASPECT_RATIO_256_BIN = { + "0.46": [160, 352], + "0.6": [192, 320], + "0.78": [224, 288], + "1.0": [256, 256], + "1.29": [288, 224], + "1.67": [320, 192], + "2.2": [352, 160], +} + +ASPECT_RATIO_512_BIN = { + "0.5": [352, 704], + "0.57": [384, 672], + "0.6": [384, 640], + "0.68": [416, 608], + "0.78": [448, 576], + "0.88": [480, 544], + "1.0": [512, 512], + "1.13": [544, 480], + "1.29": [576, 448], + "1.46": [608, 416], + "1.67": [640, 384], + "1.75": [672, 384], + "2.0": [704, 352], +} + +ASPECT_RATIO_1024_BIN = { + "0.49": [704, 1440], + "0.52": [736, 1408], + "0.53": [736, 1376], + "0.57": [768, 1344], + "0.59": [768, 1312], + "0.62": [800, 1280], + "0.67": [832, 1248], + "0.68": [832, 1216], + "0.78": [896, 1152], + "0.83": [928, 1120], + "0.94": [992, 1056], + "1.0": [1024, 1024], + "1.06": [1056, 992], + "1.13": [1088, 960], + "1.21": [1120, 928], + "1.29": [1152, 896], + "1.37": [1184, 864], + "1.46": [1216, 832], + "1.5": [1248, 832], + "1.71": [1312, 768], + "1.75": [1344, 768], + "1.87": [1376, 736], + "1.91": [1408, 736], + "2.05": [1440, 704], +} + +ASPECT_RATIO_BINS = { + 256: ASPECT_RATIO_256_BIN, + 512: ASPECT_RATIO_512_BIN, + 1024: ASPECT_RATIO_1024_BIN, +} + +logger = logging.get_logger(__name__) + + +class TextPreprocessor: + """Text preprocessing utility for PRXPipeline.""" + + def __init__(self): + """Initialize text preprocessor.""" + self.bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + r"\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) + + def clean_text(self, text: str) -> str: + """Clean text using comprehensive text processing logic.""" + # See Deepfloyd https://github.com/deep-floyd/IF/blob/develop/deepfloyd_if/modules/t5.py + text = str(text) + text = ul.unquote_plus(text) + text = text.strip().lower() + text = re.sub("", "person", text) + + # Remove all urls: + text = re.sub( + r"\b((?:https?|www):(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@))", + "", + text, + ) # regex for urls + + # @ + text = re.sub(r"@[\w\d]+\b", "", text) + + # 31C0—31EF CJK Strokes through 4E00—9FFF CJK Unified Ideographs + text = re.sub(r"[\u31c0-\u31ef]+", "", text) + text = re.sub(r"[\u31f0-\u31ff]+", "", text) + text = re.sub(r"[\u3200-\u32ff]+", "", text) + text = re.sub(r"[\u3300-\u33ff]+", "", text) + text = re.sub(r"[\u3400-\u4dbf]+", "", text) + text = re.sub(r"[\u4dc0-\u4dff]+", "", text) + text = re.sub(r"[\u4e00-\u9fff]+", "", text) + + # все виды тире / all types of dash --> "-" + text = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", + "-", + text, + ) + + # кавычки к одному стандарту + text = re.sub(r"[`´«»" "¨]", '"', text) + text = re.sub(r"['']", "'", text) + + # " and & + text = re.sub(r""?", "", text) + text = re.sub(r"&", "", text) + + # ip addresses: + text = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", text) + + # article ids: + text = re.sub(r"\d:\d\d\s+$", "", text) + + # \n + text = re.sub(r"\\n", " ", text) + + # "#123", "#12345..", "123456.." + text = re.sub(r"#\d{1,3}\b", "", text) + text = re.sub(r"#\d{5,}\b", "", text) + text = re.sub(r"\b\d{6,}\b", "", text) + + # filenames: + text = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", text) + + # Clean punctuation + text = re.sub(r"[\"\']{2,}", r'"', text) # """AUSVERKAUFT""" + text = re.sub(r"[\.]{2,}", r" ", text) + + text = re.sub(self.bad_punct_regex, r" ", text) # ***AUSVERKAUFT***, #AUSVERKAUFT + text = re.sub(r"\s+\.\s+", r" ", text) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, text)) > 3: + text = re.sub(regex2, " ", text) + + # Basic cleaning + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + text = text.strip() + + # Clean alphanumeric patterns + text = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", text) # jc6640 + text = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", text) # jc6640vc + text = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", text) # 6640vc231 + + # Common spam patterns + text = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", text) + text = re.sub(r"(free\s)?download(\sfree)?", "", text) + text = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", text) + text = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", text) + text = re.sub(r"\bpage\s+\d+\b", "", text) + + text = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", text) # j2d1a2a... + text = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", text) + + # Final cleanup + text = re.sub(r"\b\s+\:\s+", r": ", text) + text = re.sub(r"(\D[,\./])\b", r"\1 ", text) + text = re.sub(r"\s+", " ", text) + + text.strip() + + text = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", text) + text = re.sub(r"^[\'\_,\-\:;]", r"", text) + text = re.sub(r"[\'\_,\-\:\-\+]$", r"", text) + text = re.sub(r"^\.\S+$", "", text) + + return text.strip() + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import PRXPipeline + + >>> # Load pipeline with from_pretrained + >>> pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft") + >>> pipe.to("cuda") + + >>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach" + >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0] + >>> image.save("prx_output.png") + ``` +""" + + +class PRXPipeline( + DiffusionPipeline, + LoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): + r""" + Pipeline for text-to-image generation using PRX Transformer. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + transformer ([`PRXTransformer2DModel`]): + The PRX transformer model to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + text_encoder ([`T5GemmaEncoder`]): + Text encoder model for encoding prompts. + tokenizer ([`T5TokenizerFast` or `GemmaTokenizerFast`]): + Tokenizer for the text encoder. + vae ([`AutoencoderKL`] or [`AutoencoderDC`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + Supports both AutoencoderKL (8x compression) and AutoencoderDC (32x compression). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + _optional_components = ["vae"] + + def __init__( + self, + transformer: PRXTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder: T5GemmaEncoder, + tokenizer: T5TokenizerFast | GemmaTokenizerFast | AutoTokenizer, + vae: AutoencoderKL | AutoencoderDC | None = None, + default_sample_size: int | None = DEFAULT_RESOLUTION, + ): + super().__init__() + + if PRXTransformer2DModel is None: + raise ImportError( + "PRXTransformer2DModel is not available. Please ensure the transformer_prx module is properly installed." + ) + + self.text_preprocessor = TextPreprocessor() + self.default_sample_size = default_sample_size + self._guidance_scale = 1.0 + + self.register_modules( + transformer=transformer, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + ) + + self.register_to_config(default_sample_size=self.default_sample_size) + + if vae is not None: + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + else: + self.image_processor = None + + @property + def vae_scale_factor(self): + if self.vae is None: + return 8 + if hasattr(self.vae, "spatial_compression_ratio"): + return self.vae.spatial_compression_ratio + else: # Flux VAE + return 2 ** (len(self.vae.config.block_out_channels) - 1) + + @property + def do_classifier_free_guidance(self): + """Check if classifier-free guidance is enabled based on guidance scale.""" + return self._guidance_scale > 1.0 + + @property + def guidance_scale(self): + return self._guidance_scale + + def get_default_resolution(self): + """Determine the default resolution based on the loaded VAE and config. + + Returns: + int: The default sample size (height/width) to use for generation. + """ + default_from_config = getattr(self.config, "default_sample_size", None) + if default_from_config is not None: + return default_from_config + + return DEFAULT_RESOLUTION + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ): + """Prepare initial latents for the diffusion process.""" + if latents is None: + spatial_compression = self.vae_scale_factor + latent_height, latent_width = ( + height // spatial_compression, + width // spatial_compression, + ) + shape = (batch_size, num_channels_latents, latent_height, latent_width) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + return latents + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + prompt_attention_mask: torch.BoolTensor | None = None, + negative_prompt_attention_mask: torch.BoolTensor | None = None, + ): + """Encode text prompt using standard text encoder and tokenizer, or use precomputed embeddings.""" + if device is None: + device = self._execution_device + + if prompt_embeds is None: + if isinstance(prompt, str): + prompt = [prompt] + # Encode the prompts + prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = ( + self._encode_prompt_standard(prompt, device, do_classifier_free_guidance, negative_prompt) + ) + + # Duplicate embeddings for each generation per prompt + if num_images_per_prompt > 1: + # Repeat prompt embeddings + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if prompt_attention_mask is not None: + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # Repeat negative embeddings if using CFG + if do_classifier_free_guidance and negative_prompt_embeds is not None: + bs_embed, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if negative_prompt_attention_mask is not None: + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + return ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds if do_classifier_free_guidance else None, + negative_prompt_attention_mask if do_classifier_free_guidance else None, + ) + + def _tokenize_prompts(self, prompts: list[str], device: torch.device): + """Tokenize and clean prompts.""" + cleaned = [self.text_preprocessor.clean_text(text) for text in prompts] + tokens = self.tokenizer( + cleaned, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + return tokens["input_ids"].to(device), tokens["attention_mask"].bool().to(device) + + def _encode_prompt_standard( + self, + prompt: list[str], + device: torch.device, + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + ): + """Encode prompt using standard text encoder and tokenizer with batch processing.""" + batch_size = len(prompt) + + if do_classifier_free_guidance: + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + + prompts_to_encode = negative_prompt + prompt + else: + prompts_to_encode = prompt + + input_ids, attention_mask = self._tokenize_prompts(prompts_to_encode, device) + + with torch.no_grad(): + embeddings = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + )["last_hidden_state"] + + if do_classifier_free_guidance: + uncond_text_embeddings, text_embeddings = embeddings.split(batch_size, dim=0) + uncond_cross_attn_mask, cross_attn_mask = attention_mask.split(batch_size, dim=0) + else: + text_embeddings = embeddings + cross_attn_mask = attention_mask + uncond_text_embeddings = None + uncond_cross_attn_mask = None + + return text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask + + def check_inputs( + self, + prompt: str | list[str], + height: int, + width: int, + guidance_scale: float, + callback_on_step_end_tensor_inputs: list[str] | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + ): + """Check that all inputs are in correct format.""" + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + + if prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and guidance_scale > 1.0 and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided and `guidance_scale > 1.0`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + + spatial_compression = self.vae_scale_factor + if height % spatial_compression != 0 or width % spatial_compression != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {spatial_compression} but are {height} and {width}." + ) + + if guidance_scale < 1.0: + raise ValueError(f"guidance_scale has to be >= 1.0 but is {guidance_scale}") + + if callback_on_step_end_tensor_inputs is not None and not isinstance(callback_on_step_end_tensor_inputs, list): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be a list but is {callback_on_step_end_tensor_inputs}" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str = "", + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 28, + timesteps: list[int] = None, + guidance_scale: float = 4.0, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + prompt_attention_mask: torch.BoolTensor | None = None, + negative_prompt_attention_mask: torch.BoolTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + use_resolution_binning: bool = True, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` + instead. + negative_prompt (`str`, *optional*, defaults to `""`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 28): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided and `guidance_scale > 1`, negative embeddings will be generated from an + empty string. + prompt_attention_mask (`torch.BoolTensor`, *optional*): + Pre-generated attention mask for `prompt_embeds`. If not provided, attention mask will be generated + from `prompt` input argument. + negative_prompt_attention_mask (`torch.BoolTensor`, *optional*): + Pre-generated attention mask for `negative_prompt_embeds`. If not provided and `guidance_scale > 1`, + attention mask will be generated from an empty string. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.prx.PRXPipelineOutput`] instead of a plain tuple. + use_resolution_binning (`bool`, *optional*, defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + predefined aspect ratio bins. After the produced latents are decoded into images, they are resized back + to the requested resolution. Useful for generating non-square images at optimal resolutions. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self, step, timestep, callback_kwargs)`. + `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include tensors that are listed + in the `._callback_tensor_inputs` attribute. + + Examples: + + Returns: + [`~pipelines.prx.PRXPipelineOutput`] or `tuple`: [`~pipelines.prx.PRXPipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. + """ + + # 0. Set height and width + default_resolution = self.get_default_resolution() + height = height or default_resolution + width = width or default_resolution + + if use_resolution_binning: + if self.image_processor is None: + raise ValueError( + "Resolution binning requires a VAE with image_processor, but VAE is not available. " + "Set use_resolution_binning=False or provide a VAE." + ) + if self.default_sample_size not in ASPECT_RATIO_BINS: + raise ValueError( + f"Resolution binning is only supported for default_sample_size in {list(ASPECT_RATIO_BINS.keys())}, " + f"but got {self.default_sample_size}. Set use_resolution_binning=False to disable aspect ratio binning." + ) + aspect_ratio_bin = ASPECT_RATIO_BINS[self.default_sample_size] + + # Store original dimensions + orig_height, orig_width = height, width + # Map to closest resolution in the bin + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + # 1. Check inputs + self.check_inputs( + prompt, + height, + width, + guidance_scale, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + + if self.vae is None and output_type not in ["latent", "pt"]: + raise ValueError( + f"VAE is required for output_type='{output_type}' but it is not available. " + "Either provide a VAE or set output_type='latent' or 'pt' to get latent outputs." + ) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Use execution device (handles offloading scenarios including group offloading) + device = self._execution_device + + self._guidance_scale = guidance_scale + + # 2. Encode input prompt + text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = self.encode_prompt( + prompt, + device, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + # Expose standard names for callbacks parity + prompt_embeds = text_embeddings + negative_prompt_embeds = uncond_text_embeddings + + # 3. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + self.num_timesteps = len(timesteps) + + # 4. Prepare latent variables + if self.vae is not None: + num_channels_latents = self.vae.config.latent_channels + else: + # When vae is None, get latent channels from transformer + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # 5. Prepare extra step kwargs + extra_step_kwargs = {} + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_eta: + extra_step_kwargs["eta"] = 0.0 + + # 6. Prepare cross-attention embeddings and masks + if self.do_classifier_free_guidance: + ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0) + ca_mask = None + if cross_attn_mask is not None and uncond_cross_attn_mask is not None: + ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0) + else: + ca_embed = text_embeddings + ca_mask = cross_attn_mask + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Duplicate latents if using classifier-free guidance + if self.do_classifier_free_guidance: + latents_in = torch.cat([latents, latents], dim=0) + # Normalize timestep for the transformer + t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device) + else: + latents_in = latents + # Normalize timestep for the transformer + t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).to(device) + + # Forward through transformer + noise_pred = self.transformer( + hidden_states=latents_in, + timestep=t_cont, + encoder_hidden_states=ca_embed, + attention_mask=ca_mask, + return_dict=False, + )[0] + + # Apply CFG + if self.do_classifier_free_guidance: + noise_uncond, noise_text = noise_pred.chunk(2, dim=0) + noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) + + # Compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_on_step_end(self, i, t, callback_kwargs) + + # Call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # 8. Post-processing + if output_type == "latent" or (output_type == "pt" and self.vae is None): + image = latents + else: + # Unscale latents for VAE (supports both AutoencoderKL and AutoencoderDC) + scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215) + shift_factor = getattr(self.vae.config, "shift_factor", 0.0) + latents = (latents / scaling_factor) + shift_factor + # Decode using VAE (AutoencoderKL or AutoencoderDC) + image = self.vae.decode(latents, return_dict=False)[0] + # Resize back to original resolution if using binning + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + + # Use standard image processor for post-processing + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return PRXPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3f43d0ebb0b9f0c2ddf7786da461ead8541cc15f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/__init__.py @@ -0,0 +1,65 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["QwenImagePipelineOutput", "QwenImagePriorReduxPipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modeling_qwenimage"] = ["ReduxImageEncoder"] + _import_structure["pipeline_qwenimage"] = ["QwenImagePipeline"] + _import_structure["pipeline_qwenimage_controlnet"] = ["QwenImageControlNetPipeline"] + _import_structure["pipeline_qwenimage_controlnet_inpaint"] = ["QwenImageControlNetInpaintPipeline"] + _import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"] + _import_structure["pipeline_qwenimage_edit_inpaint"] = ["QwenImageEditInpaintPipeline"] + _import_structure["pipeline_qwenimage_edit_plus"] = ["QwenImageEditPlusPipeline"] + _import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"] + _import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"] + _import_structure["pipeline_qwenimage_layered"] = ["QwenImageLayeredPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_qwenimage import QwenImagePipeline + from .pipeline_qwenimage_controlnet import QwenImageControlNetPipeline + from .pipeline_qwenimage_controlnet_inpaint import QwenImageControlNetInpaintPipeline + from .pipeline_qwenimage_edit import QwenImageEditPipeline + from .pipeline_qwenimage_edit_inpaint import QwenImageEditInpaintPipeline + from .pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline + from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline + from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline + from .pipeline_qwenimage_layered import QwenImageLayeredPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..e4ed06856e6a60128d9286344aaa1d19ffdac6e8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class QwenImagePipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image] | np.ndarray diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage.py new file mode 100644 index 0000000000000000000000000000000000000000..1715aa4d4250ef49aeb172df5db2da0d477ef1d6 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -0,0 +1,759 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import QwenImageLoraLoaderMixin +from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import QwenImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import QwenImagePipeline + + >>> pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=50).images[0] + >>> image.save("qwenimage.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): + r""" + The QwenImage pipeline for text-to-image generation. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + transformer: QwenImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 1024 + self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 34 + self.default_sample_size = 128 + + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + def _get_qwen_prompt_embeds( + self, + prompt: str | list[str] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + ).to(device) + encoder_hidden_states = self.text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + + prompt_embeds = prompt_embeds[:, :max_sequence_length] + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if prompt_embeds_mask is not None: + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + if prompt_embeds_mask.all(): + prompt_embeds_mask = None + + return prompt_embeds, prompt_embeds_mask + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + if latents is not None: + return latents.to(device=device, dtype=dtype) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + true_cfg_scale: float = 4.0, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float | None = None, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + true_cfg_scale (`float`, *optional*, defaults to 1.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is enabled by + setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale encourages to + generate images that are closely linked to the text `prompt`, usually at the expense of lower image + quality. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to None): + A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance + where the guidance scale is applied during inference through noise prediction rescaling, guidance + distilled models take the guidance scale directly as an input parameter during forward pass. Guidance + scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images + that are closely linked to the text `prompt`, usually at the expense of lower image quality. This + parameter in the pipeline is there to support future guidance-distilled models when they come up. It is + ignored when not using guidance distilled models. To enable traditional classifier-free guidance, + please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should + enable classifier-free guidance computations). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + + if true_cfg_scale > 1 and not has_neg_prompt: + logger.warning( + f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided." + ) + elif true_cfg_scale <= 1 and has_neg_prompt: + logger.warning( + " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1" + ) + + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds and guidance_scale is None: + raise ValueError("guidance_scale is required for guidance-distilled model.") + elif self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + elif not self.transformer.config.guidance_embeds and guidance_scale is not None: + logger.warning( + f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled." + ) + guidance = None + elif not self.transformer.config.guidance_embeds and guidance_scale is None: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..85a936f9ec24dbe5451faee37997cba8ab569689 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py @@ -0,0 +1,991 @@ +# Copyright 2025 Qwen-Image Team, InstantX Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import QwenImageLoraLoaderMixin +from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel +from ...models.controlnets.controlnet_qwenimage import QwenImageControlNetModel, QwenImageMultiControlNetModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import QwenImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers.utils import load_image + >>> from diffusers import QwenImageControlNetModel, QwenImageMultiControlNetModel, QwenImageControlNetPipeline + + >>> # QwenImageControlNetModel + >>> controlnet = QwenImageControlNetModel.from_pretrained( + ... "InstantX/Qwen-Image-ControlNet-Union", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = QwenImageControlNetPipeline.from_pretrained( + ... "Qwen/Qwen-Image", controlnet=controlnet, torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + >>> prompt = "Aesthetics art, traditional asian pagoda, elaborate golden accents, sky blue and white color palette, swirling cloud pattern, digital illustration, east asian architecture, ornamental rooftop, intricate detailing on building, cultural representation." + >>> negative_prompt = " " + >>> control_image = load_image( + ... "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Union/resolve/main/conds/canny.png" + ... ) + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe( + ... prompt, + ... negative_prompt=negative_prompt, + ... control_image=control_image, + ... controlnet_conditioning_scale=1.0, + ... num_inference_steps=30, + ... true_cfg_scale=4.0, + ... ).images[0] + >>> image.save("qwenimage_cn_union.png") + + >>> # QwenImageMultiControlNetModel + >>> controlnet = QwenImageControlNetModel.from_pretrained( + ... "InstantX/Qwen-Image-ControlNet-Union", torch_dtype=torch.bfloat16 + ... ) + >>> controlnet = QwenImageMultiControlNetModel([controlnet]) + >>> pipe = QwenImageControlNetPipeline.from_pretrained( + ... "Qwen/Qwen-Image", controlnet=controlnet, torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + >>> prompt = "Aesthetics art, traditional asian pagoda, elaborate golden accents, sky blue and white color palette, swirling cloud pattern, digital illustration, east asian architecture, ornamental rooftop, intricate detailing on building, cultural representation." + >>> negative_prompt = " " + >>> control_image = load_image( + ... "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Union/resolve/main/conds/canny.png" + ... ) + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe( + ... prompt, + ... negative_prompt=negative_prompt, + ... control_image=[control_image, control_image], + ... controlnet_conditioning_scale=[0.5, 0.5], + ... num_inference_steps=30, + ... true_cfg_scale=4.0, + ... ).images[0] + >>> image.save("qwenimage_cn_union_multi.png") + ``` +""" + + +# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): + r""" + The QwenImage pipeline for text-to-image generation. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + transformer: QwenImageTransformer2DModel, + controlnet: QwenImageControlNetModel | QwenImageMultiControlNetModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + controlnet=controlnet, + ) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 1024 + self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 34 + self.default_sample_size = 128 + + # Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.extract_masked_hidden + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + # Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.get_qwen_prompt_embeds + def _get_qwen_prompt_embeds( + self, + prompt: str | list[str] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + ).to(device) + encoder_hidden_states = self.text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + # Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if prompt_embeds_mask is not None: + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + if prompt_embeds_mask.all(): + prompt_embeds_mask = None + + return prompt_embeds, prompt_embeds_mask + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + if latents is not None: + return latents.to(device=device, dtype=dtype) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + return latents + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + true_cfg_scale: float = 4.0, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float | None = None, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + control_image: PipelineImageInput = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + true_cfg_scale (`float`, *optional*, defaults to 1.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is enabled by + setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale encourages to + generate images that are closely linked to the text `prompt`, usually at the expense of lower image + quality. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to None): + A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance + where the guidance scale is applied during inference through noise prediction rescaling, guidance + distilled models take the guidance scale directly as an input parameter during forward pass. Guidance + scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images + that are closely linked to the text `prompt`, usually at the expense of lower image quality. This + parameter in the pipeline is there to support future guidance-distilled models when they come up. It is + ignored when not using guidance distilled models. To enable traditional classifier-free guidance, + please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should + enable classifier-free guidance computations). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(control_image) if isinstance(self.controlnet, QwenImageMultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + + if true_cfg_scale > 1 and not has_neg_prompt: + logger.warning( + f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided." + ) + elif true_cfg_scale <= 1 and has_neg_prompt: + logger.warning( + " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1" + ) + + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 3. Prepare control image + num_channels_latents = self.transformer.config.in_channels // 4 + if isinstance(self.controlnet, QwenImageControlNetModel): + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image.shape[-2:] + + if control_image.ndim == 4: + control_image = control_image.unsqueeze(2) + + # vae encode + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) + latents_mean = (torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1)).to( + device + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device + ) + + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) + control_image = (control_image - latents_mean) * latents_std + + control_image = control_image.permute(0, 2, 1, 3, 4) + + # pack + control_image = self._pack_latents( + control_image, + batch_size=control_image.shape[0], + num_channels_latents=num_channels_latents, + height=control_image.shape[3], + width=control_image.shape[4], + ).to(dtype=prompt_embeds.dtype, device=device) + + else: + if isinstance(self.controlnet, QwenImageMultiControlNetModel): + control_images = [] + for control_image_ in control_image: + control_image_ = self.prepare_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + + height, width = control_image_.shape[-2:] + + if control_image_.ndim == 4: + control_image_ = control_image_.unsqueeze(2) + + # vae encode + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1) + ).to(device) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view( + 1, self.vae.config.z_dim, 1, 1, 1 + ).to(device) + + control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator) + control_image_ = (control_image_ - latents_mean) * latents_std + + control_image_ = control_image_.permute(0, 2, 1, 3, 4) + + # pack + control_image_ = self._pack_latents( + control_image_, + batch_size=control_image_.shape[0], + num_channels_latents=num_channels_latents, + height=control_image_.shape[3], + width=control_image_.shape[4], + ).to(dtype=prompt_embeds.dtype, device=device) + + control_images.append(control_image_) + + control_image = control_images + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(self.controlnet, QwenImageControlNetModel) else keeps) + + # handle guidance + if self.transformer.config.guidance_embeds and guidance_scale is None: + raise ValueError("guidance_scale is required for guidance-distilled model.") + elif self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + elif not self.transformer.config.guidance_embeds and guidance_scale is not None: + logger.warning( + f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled." + ) + guidance = None + elif not self.transformer.config.guidance_embeds and guidance_scale is None: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + # controlnet + controlnet_block_samples = self.controlnet( + hidden_states=latents, + controlnet_cond=control_image, + conditioning_scale=cond_scale, + timestep=timestep / 1000, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_mask=prompt_embeds_mask, + img_shapes=img_shapes, + return_dict=False, + ) + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_mask=prompt_embeds_mask, + img_shapes=img_shapes, + controlnet_block_samples=controlnet_block_samples, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + controlnet_block_samples=controlnet_block_samples, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..b1da59cb4f6c8a796beafbc685ef8f594d126cc3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py @@ -0,0 +1,941 @@ +# Copyright 2025 Qwen-Image Team, The InstantX Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import QwenImageLoraLoaderMixin +from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel +from ...models.controlnets.controlnet_qwenimage import QwenImageControlNetModel, QwenImageMultiControlNetModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import QwenImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers.utils import load_image + >>> from diffusers import QwenImageControlNetModel, QwenImageControlNetInpaintPipeline + + >>> base_model_path = "Qwen/Qwen-Image" + >>> controlnet_model_path = "InstantX/Qwen-Image-ControlNet-Inpainting" + >>> controlnet = QwenImageControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=torch.bfloat16) + >>> pipe = QwenImageControlNetInpaintPipeline.from_pretrained( + ... base_model_path, controlnet=controlnet, torch_dtype=torch.bfloat16 + ... ).to("cuda") + >>> image = load_image( + ... "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets/images/image1.png" + ... ) + >>> mask_image = load_image( + ... "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets/masks/mask1.png" + ... ) + >>> prompt = "一辆绿色的出租车行驶在路上" + >>> result = pipe( + ... prompt=prompt, + ... control_image=image, + ... control_mask=mask_image, + ... controlnet_conditioning_scale=1.0, + ... width=mask_image.size[0], + ... height=mask_image.size[1], + ... true_cfg_scale=4.0, + ... ).images[0] + >>> image.save("qwenimage_controlnet_inpaint.png") + ``` +""" + + +# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): + r""" + The QwenImage pipeline for text-to-image generation. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + transformer: QwenImageTransformer2DModel, + controlnet: QwenImageControlNetModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + controlnet=controlnet, + ) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + do_resize=True, + do_convert_grayscale=True, + do_normalize=False, + do_binarize=True, + ) + + self.tokenizer_max_length = 1024 + self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 34 + self.default_sample_size = 128 + + # Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.extract_masked_hidden + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + # Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.get_qwen_prompt_embeds + def _get_qwen_prompt_embeds( + self, + prompt: str | list[str] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + ).to(self.device) + encoder_hidden_states = self.text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + # Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + max_sequence_length: int = 1024, + ): + r""" + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + + return prompt_embeds, prompt_embeds_mask + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + if latents is not None: + return latents.to(device=device, dtype=dtype) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + return latents + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + def prepare_image_with_mask( + self, + image, + mask, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + image = image.to(device=device, dtype=dtype) # (bsz, 3, height_ori, width_ori) + + # Prepare mask + if isinstance(mask, torch.Tensor): + pass + else: + mask = self.mask_processor.preprocess(mask, height=height, width=width) + mask = mask.repeat_interleave(repeat_by, dim=0) + mask = mask.to(device=device, dtype=dtype) # (bsz, 1, height_ori, width_ori) + + if image.ndim == 4: + image = image.unsqueeze(2) + + if mask.ndim == 4: + mask = mask.unsqueeze(2) + + # Get masked image + masked_image = image.clone() + masked_image[(mask > 0.5).repeat(1, 3, 1, 1, 1)] = -1 # (bsz, 3, 1, height_ori, width_ori) + + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) + latents_mean = (torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1)).to(device) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device + ) + + # Encode to latents + image_latents = self.vae.encode(masked_image.to(self.vae.dtype)).latent_dist.sample() + image_latents = (image_latents - latents_mean) * latents_std + image_latents = image_latents.to(dtype) # torch.Size([1, 16, 1, height_ori//8, width_ori//8]) + + mask = torch.nn.functional.interpolate( + mask, size=(image_latents.shape[-3], image_latents.shape[-2], image_latents.shape[-1]) + ) + mask = 1 - mask # torch.Size([1, 1, 1, height_ori//8, width_ori//8]) + + control_image = torch.cat( + [image_latents, mask], dim=1 + ) # torch.Size([1, 16+1, 1, height_ori//8, width_ori//8]) + + control_image = control_image.permute(0, 2, 1, 3, 4) # torch.Size([1, 1, 16+1, height_ori//8, width_ori//8]) + + # pack + control_image = self._pack_latents( + control_image, + batch_size=control_image.shape[0], + num_channels_latents=control_image.shape[2], + height=control_image.shape[3], + width=control_image.shape[4], + ) + + if do_classifier_free_guidance and not guess_mode: + control_image = torch.cat([control_image] * 2) + + return control_image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + true_cfg_scale: float = 4.0, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 1.0, + control_guidance_start: float | list[float] = 0.0, + control_guidance_end: float | list[float] = 1.0, + control_image: PipelineImageInput = None, + control_mask: PipelineImageInput = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + true_cfg_scale (`float`, *optional*, defaults to 1.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(control_image) if isinstance(self.controlnet, QwenImageMultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 3. Prepare control image + num_channels_latents = self.transformer.config.in_channels // 4 + if isinstance(self.controlnet, QwenImageControlNetModel): + control_image = self.prepare_image_with_mask( + image=control_image, + mask=control_mask, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(self.controlnet, QwenImageControlNetModel) else keeps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + # controlnet + controlnet_block_samples = self.controlnet( + hidden_states=latents, + controlnet_cond=control_image.to(dtype=latents.dtype, device=device), + conditioning_scale=cond_scale, + timestep=timestep / 1000, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_mask=prompt_embeds_mask, + img_shapes=img_shapes, + return_dict=False, + ) + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_mask=prompt_embeds_mask, + img_shapes=img_shapes, + controlnet_block_samples=controlnet_block_samples, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + controlnet_block_samples=controlnet_block_samples, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py new file mode 100644 index 0000000000000000000000000000000000000000..15e72a010ce52b54d1944b6183c1c0ea45373a34 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -0,0 +1,895 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable + +import numpy as np +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import QwenImageLoraLoaderMixin +from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import QwenImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from PIL import Image + >>> from diffusers import QwenImageEditPipeline + >>> from diffusers.utils import load_image + + >>> pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png" + ... ).convert("RGB") + >>> prompt = ( + ... "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors" + ... ) + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(image, prompt, num_inference_steps=50).images[0] + >>> image.save("qwenimage_edit.png") + ``` +""" + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def calculate_dimensions(target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + return width, height, None + + +class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): + r""" + The Qwen-Image-Edit pipeline for image editing. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + processor: Qwen2VLProcessor, + transformer: QwenImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + processor=processor, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 1024 + + self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 64 + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + def _get_qwen_prompt_embeds( + self, + prompt: str | list[str] = None, + image: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + + model_inputs = self.processor( + text=txt, + images=image, + padding=True, + return_tensors="pt", + ).to(device) + + outputs = self.text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + + hidden_states = outputs.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + def encode_prompt( + self, + prompt: str | list[str], + image: torch.Tensor | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + image (`torch.Tensor`, *optional*): + image to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + + return prompt_embeds, prompt_embeds_mask + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + image_latents = (image_latents - latents_mean) / latents_std + + return image_latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + def prepare_latents( + self, + image, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + image_latents = None + if image is not None: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latent_height, image_latent_width = image_latents.shape[3:] + image_latents = self._pack_latents( + image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents, image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput | None = None, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + true_cfg_scale: float = 4.0, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float | None = None, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + true_cfg_scale (`float`, *optional*, defaults to 1.0): + true_cfg_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free + Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of + equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is + enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale + encourages to generate images that are closely linked to the text `prompt`, usually at the expense of + lower image quality. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to None): + A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance + where the guidance scale is applied during inference through noise prediction rescaling, guidance + distilled models take the guidance scale directly as an input parameter during forward pass. Guidance + scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images + that are closely linked to the text `prompt`, usually at the expense of lower image quality. This + parameter in the pipeline is there to support future guidance-distilled models when they come up. It is + ignored when not using guidance distilled models. To enable traditional classifier-free guidance, + please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should + enable classifier-free guidance computations). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + image_size = image[0].size if isinstance(image, list) else image.size + calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1]) + height = height or calculated_height + width = width or calculated_width + + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # 3. Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + image = self.image_processor.resize(image, calculated_height, calculated_width) + prompt_image = image + image = self.image_processor.preprocess(image, calculated_height, calculated_width) + image = image.unsqueeze(2) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + + if true_cfg_scale > 1 and not has_neg_prompt: + logger.warning( + f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided." + ) + elif true_cfg_scale <= 1 and has_neg_prompt: + logger.warning( + " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1" + ) + + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + image=prompt_image, + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + image=prompt_image, + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, image_latents = self.prepare_latents( + image, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + img_shapes = [ + [ + (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), + (1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2), + ] + ] * batch_size + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds and guidance_scale is None: + raise ValueError("guidance_scale is required for guidance-distilled model.") + elif self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + elif not self.transformer.config.guidance_embeds and guidance_scale is not None: + logger.warning( + f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled." + ) + guidance = None + elif not self.transformer.config.guidance_embeds and guidance_scale is None: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..20a2748bc7f9ed484af730c8ac5f3c84be5278e3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py @@ -0,0 +1,1118 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import QwenImageLoraLoaderMixin +from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import QwenImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from PIL import Image + >>> from diffusers import QwenImageEditInpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = QwenImageEditInpaintPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + >>> source = load_image(img_url) + >>> mask = load_image(mask_url) + >>> image = pipe( + ... prompt=prompt, negative_prompt=" ", image=source, mask_image=mask, strength=1.0, num_inference_steps=50 + ... ).images[0] + >>> image.save("qwenimage_inpainting.png") + ``` +""" + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.calculate_dimensions +def calculate_dimensions(target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + return width, height, None + + +class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): + r""" + The Qwen-Image-Edit pipeline for image editing. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + processor: Qwen2VLProcessor, + transformer: QwenImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + processor=processor, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + vae_latent_channels=self.latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.vl_processor = processor + self.tokenizer_max_length = 1024 + + self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 64 + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._get_qwen_prompt_embeds + def _get_qwen_prompt_embeds( + self, + prompt: str | list[str] = None, + image: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + + model_inputs = self.processor( + text=txt, + images=image, + padding=True, + return_tensors="pt", + ).to(device) + + outputs = self.text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + + hidden_states = outputs.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + image: torch.Tensor | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + image (`torch.Tensor`, *optional*): + image to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + + return prompt_embeds, prompt_embeds_mask + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.check_inputs + def check_inputs( + self, + prompt, + image, + mask_image, + strength, + height, + width, + output_type, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_img2img.QwenImageImg2ImgPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + image_latents.device, image_latents.dtype + ) + + image_latents = (image_latents - latents_mean) * latents_std + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.prepare_latents + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + # If image is [B,C,H,W] -> add T=1. If it's already [B,C,T,H,W], leave it. + if image.dim() == 4: + image = image.unsqueeze(2) + elif image.dim() != 5: + raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.") + + if latents is not None: + return latents.to(device=device, dtype=dtype) + + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) # [B,z,1,H',W'] + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latents = image_latents.transpose(1, 2) # [B,1,z,H',W'] + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + else: + noise = latents.to(device) + latents = noise + + noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + return latents, noise, image_latents + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if masked_image.dim() == 4: + masked_image = masked_image.unsqueeze(2) + elif masked_image.dim() != 5: + raise ValueError(f"Expected image dims 4 or 5, got {masked_image.dim()}.") + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == self.latent_channels: + masked_image_latents = masked_image + else: + masked_image_latents = self._encode_vae_image(image=masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + mask = self._pack_latents( + mask.repeat(1, num_channels_latents, 1, 1), + batch_size, + num_channels_latents, + height, + width, + ) + + return mask, masked_image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput | None = None, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + mask_image: PipelineImageInput = None, + masked_image_latents: PipelineImageInput = None, + true_cfg_scale: float = 4.0, + height: int | None = None, + width: int | None = None, + padding_mask_crop: int | None = None, + strength: float = 0.6, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float | None = None, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + true_cfg_scale (`float`, *optional*, defaults to 1.0): + true_cfg_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free + Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of + equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is + enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale + encourages to generate images that are closely linked to the text `prompt`, usually at the expense of + lower image quality. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + mask_image_latent (`torch.Tensor`, `list[torch.Tensor]`): + `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask + latents tensor will ge generated by `mask_image`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to None): + A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance + where the guidance scale is applied during inference through noise prediction rescaling, guidance + distilled models take the guidance scale directly as an input parameter during forward pass. Guidance + scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images + that are closely linked to the text `prompt`, usually at the expense of lower image quality. This + parameter in the pipeline is there to support future guidance-distilled models when they come up. It is + ignored when not using guidance distilled models. To enable traditional classifier-free guidance, + please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should + enable classifier-free guidance computations). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + image_size = image[0].size if isinstance(image, list) else image.size + calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1]) + + # height and width are the same as the calculated height and width + height = calculated_height + width = calculated_width + + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + mask_image, + strength, + height, + width, + output_type=output_type, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + padding_mask_crop=padding_mask_crop, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # 3. Preprocess image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + image = self.image_processor.resize(image, calculated_height, calculated_width) + original_image = image + prompt_image = image + image = self.image_processor.preprocess( + image, + height=calculated_height, + width=calculated_width, + crops_coords=crops_coords, + resize_mode=resize_mode, + ) + image = image.to(dtype=torch.float32) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + + if true_cfg_scale > 1 and not has_neg_prompt: + logger.warning( + f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided." + ) + elif true_cfg_scale <= 1 and has_neg_prompt: + logger.warning( + " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1" + ) + + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + image=prompt_image, + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + image=prompt_image, + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, noise, image_latents = self.prepare_latents( + image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + mask_condition = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is None: + masked_image = image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + img_shapes = [ + [ + (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), + (1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2), + ] + ] * batch_size + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds and guidance_scale is None: + raise ValueError("guidance_scale is required for guidance-distilled model.") + elif self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + elif not self.transformer.config.guidance_embeds and guidance_scale is not None: + logger.warning( + f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled." + ) + guidance = None + elif not self.transformer.config.guidance_embeds and guidance_scale is None: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # for 64 channel transformer only. + init_latents_proper = image_latents + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + if padding_mask_crop is not None: + image = [ + self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image + ] + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py new file mode 100644 index 0000000000000000000000000000000000000000..5887834585711c1ab52fc242ba9c28631e6253f0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -0,0 +1,886 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable + +import numpy as np +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import QwenImageLoraLoaderMixin +from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import QwenImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from PIL import Image + >>> from diffusers import QwenImageEditPlusPipeline + >>> from diffusers.utils import load_image + + >>> pipe = QwenImageEditPlusPipeline.from_pretrained("Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png" + ... ).convert("RGB") + >>> prompt = ( + ... "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors" + ... ) + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(image, prompt, num_inference_steps=50).images[0] + >>> image.save("qwenimage_edit_plus.png") + ``` +""" + +CONDITION_IMAGE_SIZE = 384 * 384 +VAE_IMAGE_SIZE = 1024 * 1024 + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def calculate_dimensions(target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + return width, height + + +class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): + r""" + The Qwen-Image-Edit pipeline for image editing. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + processor: Qwen2VLProcessor, + transformer: QwenImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + processor=processor, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 1024 + + self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 64 + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + def _get_qwen_prompt_embeds( + self, + prompt: str | list[str] = None, + image: torch.Tensor | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>" + if isinstance(image, list): + base_img_prompt = "" + for i, img in enumerate(image): + base_img_prompt += img_prompt_template.format(i + 1) + elif image is not None: + base_img_prompt = img_prompt_template.format(1) + else: + base_img_prompt = "" + + template = self.prompt_template_encode + + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(base_img_prompt + e) for e in prompt] + + model_inputs = self.processor( + text=txt, + images=image, + padding=True, + return_tensors="pt", + ).to(device) + + outputs = self.text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + + hidden_states = outputs.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + image: torch.Tensor | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + image (`torch.Tensor`, *optional*): + image to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + + return prompt_embeds, prompt_embeds_mask + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + image_latents = (image_latents - latents_mean) / latents_std + + return image_latents + + def prepare_latents( + self, + images, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + image_latents = None + if images is not None: + if not isinstance(images, list): + images = [images] + all_image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latent_height, image_latent_width = image_latents.shape[3:] + image_latents = self._pack_latents( + image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width + ) + all_image_latents.append(image_latents) + image_latents = torch.cat(all_image_latents, dim=1) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents, image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput | None = None, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + true_cfg_scale: float = 4.0, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float | None = None, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + true_cfg_scale (`float`, *optional*, defaults to 1.0): + true_cfg_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free + Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of + equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is + enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale + encourages to generate images that are closely linked to the text `prompt`, usually at the expense of + lower image quality. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to None): + A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance + where the guidance scale is applied during inference through noise prediction rescaling, guidance + distilled models take the guidance scale directly as an input parameter during forward pass. Guidance + scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images + that are closely linked to the text `prompt`, usually at the expense of lower image quality. This + parameter in the pipeline is there to support future guidance-distilled models when they come up. It is + ignored when not using guidance distilled models. To enable traditional classifier-free guidance, + please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should + enable classifier-free guidance computations). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + image_size = image[-1].size if isinstance(image, list) else image.size + calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1]) + height = height or calculated_height + width = width or calculated_width + + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # QwenImageEditPlusPipeline does not currently support batch_size > 1 + if batch_size > 1: + raise ValueError( + f"QwenImageEditPlusPipeline currently only supports batch_size=1, but received batch_size={batch_size}. " + "Please process prompts one at a time." + ) + + device = self._execution_device + # 3. Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + if not isinstance(image, list): + image = [image] + condition_image_sizes = [] + condition_images = [] + vae_image_sizes = [] + vae_images = [] + for img in image: + image_width, image_height = img.size + condition_width, condition_height = calculate_dimensions( + CONDITION_IMAGE_SIZE, image_width / image_height + ) + vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height) + condition_image_sizes.append((condition_width, condition_height)) + vae_image_sizes.append((vae_width, vae_height)) + condition_images.append(self.image_processor.resize(img, condition_height, condition_width)) + vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2)) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + + if true_cfg_scale > 1 and not has_neg_prompt: + logger.warning( + f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided." + ) + elif true_cfg_scale <= 1 and has_neg_prompt: + logger.warning( + " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1" + ) + + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + image=condition_images, + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + image=condition_images, + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, image_latents = self.prepare_latents( + vae_images, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + img_shapes = [ + [ + (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), + *[ + (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2) + for vae_width, vae_height in vae_image_sizes + ], + ] + ] * batch_size + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds and guidance_scale is None: + raise ValueError("guidance_scale is required for guidance-distilled model.") + elif self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + elif not self.transformer.config.guidance_embeds and guidance_scale is not None: + logger.warning( + f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled." + ) + guidance = None + elif not self.transformer.config.guidance_embeds and guidance_scale is None: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..42e63f8919a2ddb8ec41995aff7fd0737e65fb4d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -0,0 +1,862 @@ +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import QwenImageLoraLoaderMixin +from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import QwenImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import QwenImageImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> pipe = QwenImageImg2ImgPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16) + >>> pipe = pipe.to("cuda") + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> init_image = load_image(url).resize((1024, 1024)) + >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney" + >>> images = pipe(prompt=prompt, negative_prompt=" ", image=init_image, strength=0.95).images[0] + >>> images.save("qwenimage_img2img.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): + r""" + The QwenImage pipeline for text-to-image generation. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + transformer: QwenImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16 + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels + ) + self.tokenizer_max_length = 1024 + self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 34 + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._get_qwen_prompt_embeds + def _get_qwen_prompt_embeds( + self, + prompt: str | list[str] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + ).to(device) + encoder_hidden_states = self.text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + image_latents.device, image_latents.dtype + ) + + image_latents = (image_latents - latents_mean) * latents_std + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + + prompt_embeds = prompt_embeds[:, :max_sequence_length] + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if prompt_embeds_mask is not None: + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + if prompt_embeds_mask.all(): + prompt_embeds_mask = None + + return prompt_embeds, prompt_embeds_mask + + def check_inputs( + self, + prompt, + strength, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + # If image is [B,C,H,W] -> add T=1. If it's already [B,C,T,H,W], leave it. + if image.dim() == 4: + image = image.unsqueeze(2) + elif image.dim() != 5: + raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.") + + if latents is not None: + return latents.to(device=device, dtype=dtype) + + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) # [B,z,1,H',W'] + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latents = image_latents.transpose(1, 2) # [B,1,z,H',W'] + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + true_cfg_scale: float = 4.0, + image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + strength: float = 0.6, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float | None = None, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is enabled by + setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale encourages to + generate images that are closely linked to the text `prompt`, usually at the expense of lower image + quality. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to None): + A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance + where the guidance scale is applied during inference through noise prediction rescaling, guidance + distilled models take the guidance scale directly as an input parameter during forward pass. Guidance + scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images + that are closely linked to the text `prompt`, usually at the expense of lower image quality. This + parameter in the pipeline is there to support future guidance-distilled models when they come up. It is + ignored when not using guidance distilled models. To enable traditional classifier-free guidance, + please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should + enable classifier-free guidance computations). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + strength, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Preprocess image + init_image = self.image_processor.preprocess(image, height=height, width=width) + init_image = init_image.to(dtype=torch.float32) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + + if true_cfg_scale > 1 and not has_neg_prompt: + logger.warning( + f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided." + ) + elif true_cfg_scale <= 1 and has_neg_prompt: + logger.warning( + " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1" + ) + + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds and guidance_scale is None: + raise ValueError("guidance_scale is required for guidance-distilled model.") + elif self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + elif not self.transformer.config.guidance_embeds and guidance_scale is not None: + logger.warning( + f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled." + ) + guidance = None + elif not self.transformer.config.guidance_embeds and guidance_scale is None: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..5baf5bf5f77dbe040ac05ebff20ffb85704e6749 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -0,0 +1,1049 @@ +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import QwenImageLoraLoaderMixin +from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import QwenImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import QwenImageInpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = QwenImageInpaintPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + >>> source = load_image(img_url) + >>> mask = load_image(mask_url) + >>> image = pipe(prompt=prompt, negative_prompt=" ", image=source, mask_image=mask, strength=0.85).images[0] + >>> image.save("qwenimage_inpainting.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): + r""" + The QwenImage pipeline for text-to-image generation. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + transformer: QwenImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16 + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels + ) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + vae_latent_channels=self.latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.tokenizer_max_length = 1024 + self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 34 + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._get_qwen_prompt_embeds + def _get_qwen_prompt_embeds( + self, + prompt: str | list[str] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + ).to(device) + encoder_hidden_states = self.text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_img2img.QwenImageImg2ImgPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + image_latents.device, image_latents.dtype + ) + + image_latents = (image_latents - latents_mean) * latents_std + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + + prompt_embeds = prompt_embeds[:, :max_sequence_length] + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if prompt_embeds_mask is not None: + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + if prompt_embeds_mask.all(): + prompt_embeds_mask = None + + return prompt_embeds, prompt_embeds_mask + + def check_inputs( + self, + prompt, + image, + mask_image, + strength, + height, + width, + output_type, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + # If image is [B,C,H,W] -> add T=1. If it's already [B,C,T,H,W], leave it. + if image.dim() == 4: + image = image.unsqueeze(2) + elif image.dim() != 5: + raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.") + + if latents is not None: + return latents.to(device=device, dtype=dtype) + + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) # [B,z,1,H',W'] + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latents = image_latents.transpose(1, 2) # [B,1,z,H',W'] + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + else: + noise = latents.to(device) + latents = noise + + noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + return latents, noise, image_latents + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if masked_image.dim() == 4: + masked_image = masked_image.unsqueeze(2) + elif masked_image.dim() != 5: + raise ValueError(f"Expected image dims 4 or 5, got {masked_image.dim()}.") + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == self.latent_channels: + masked_image_latents = masked_image + else: + masked_image_latents = self._encode_vae_image(image=masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + mask = self._pack_latents( + mask.repeat(1, num_channels_latents, 1, 1), + batch_size, + num_channels_latents, + height, + width, + ) + + return mask, masked_image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + true_cfg_scale: float = 4.0, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + padding_mask_crop: int | None = None, + strength: float = 0.6, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float | None = None, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is enabled by + setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale encourages to + generate images that are closely linked to the text `prompt`, usually at the expense of lower image + quality. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + mask_image_latent (`torch.Tensor`, `list[torch.Tensor]`): + `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask + latents tensor will be generated by `mask_image`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to None): + A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance + where the guidance scale is applied during inference through noise prediction rescaling, guidance + distilled models take the guidance scale directly as an input parameter during forward pass. Guidance + scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images + that are closely linked to the text `prompt`, usually at the expense of lower image quality. This + parameter in the pipeline is there to support future guidance-distilled models when they come up. It is + ignored when not using guidance distilled models. To enable traditional classifier-free guidance, + please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should + enable classifier-free guidance computations). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + mask_image, + strength, + height, + width, + output_type=output_type, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + padding_mask_crop=padding_mask_crop, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Preprocess image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + + if true_cfg_scale > 1 and not has_neg_prompt: + logger.warning( + f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided." + ) + elif true_cfg_scale <= 1 and has_neg_prompt: + logger.warning( + " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1" + ) + + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + + latents, noise, image_latents = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + mask_condition = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds and guidance_scale is None: + raise ValueError("guidance_scale is required for guidance-distilled model.") + elif self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + elif not self.transformer.config.guidance_embeds and guidance_scale is not None: + logger.warning( + f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled." + ) + guidance = None + elif not self.transformer.config.guidance_embeds and guidance_scale is None: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # for 64 channel transformer only. + init_latents_proper = image_latents + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + if padding_mask_crop is not None: + image = [ + self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image + ] + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py new file mode 100644 index 0000000000000000000000000000000000000000..c7a44d880f9bc592e5e5a675901e2ae2360a4023 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py @@ -0,0 +1,903 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable + +import numpy as np +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import QwenImageLoraLoaderMixin +from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import QwenImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from PIL import Image + >>> from diffusers import QwenImageLayeredPipeline + >>> from diffusers.utils import load_image + + >>> pipe = QwenImageLayeredPipeline.from_pretrained("Qwen/Qwen-Image-Layered", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png" + ... ).convert("RGBA") + >>> prompt = "" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> images = pipe( + ... image, + ... prompt, + ... num_inference_steps=50, + ... true_cfg_scale=4.0, + ... layers=4, + ... resolution=640, + ... cfg_normalize=False, + ... use_en_prompt=True, + ... ).images[0] + >>> for i, image in enumerate(images): + ... image.save(f"{i}.out.png") + ``` +""" + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus.calculate_dimensions +def calculate_dimensions(target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + return width, height + + +class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): + r""" + The Qwen-Image-Layered pipeline for image decomposing. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + processor: Qwen2VLProcessor, + transformer: QwenImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + processor=processor, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.vl_processor = processor + self.tokenizer_max_length = 1024 + + self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 34 + self.image_caption_prompt_cn = """<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n# 图像标注器\n你是一个专业的图像标注器。请基于输入图像,撰写图注:\n1. +使用自然、描述性的语言撰写图注,不要使用结构化形式或富文本形式。\n2. 通过加入以下内容,丰富图注细节:\n - 对象的属性:如数量、颜色、形状、大小、位置、材质、状态、动作等\n - +对象间的视觉关系:如空间关系、功能关系、动作关系、从属关系、比较关系、因果关系等\n - 环境细节:例如天气、光照、颜色、纹理、气氛等\n - 文字内容:识别图像中清晰可见的文字,不做翻译和解释,用引号在图注中强调\n3. +保持真实性与准确性:\n - 不要使用笼统的描述\n - +描述图像中所有可见的信息,但不要加入没有在图像中出现的内容\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>assistant\n""" + self.image_caption_prompt_en = """<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n# Image Annotator\nYou are a professional +image annotator. Please write an image caption based on the input image:\n1. Write the caption using natural, +descriptive language without structured formats or rich text.\n2. Enrich caption details by including: \n - Object +attributes, such as quantity, color, shape, size, material, state, position, actions, and so on\n - Vision Relations +between objects, such as spatial relations, functional relations, possessive relations, attachment relations, action +relations, comparative relations, causal relations, and so on\n - Environmental details, such as weather, lighting, +colors, textures, atmosphere, and so on\n - Identify the text clearly visible in the image, without translation or +explanation, and highlight it in the caption with quotation marks\n3. Maintain authenticity and accuracy:\n - Avoid +generalizations\n - Describe all visible information in the image, while do not add information not explicitly shown in +the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>assistant\n""" + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + def _get_qwen_prompt_embeds( + self, + prompt: str | list[str] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, + padding=True, + return_tensors="pt", + ).to(device) + encoder_hidden_states = self.text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + + prompt_embeds = prompt_embeds[:, :max_sequence_length] + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if prompt_embeds_mask is not None: + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + if prompt_embeds_mask.all(): + prompt_embeds_mask = None + + return prompt_embeds, prompt_embeds_mask + + def get_image_caption(self, prompt_image, use_en_prompt=True, device=None): + if use_en_prompt: + prompt = self.image_caption_prompt_en + else: + prompt = self.image_caption_prompt_cn + model_inputs = self.vl_processor( + text=prompt, + images=prompt_image, + padding=True, + return_tensors="pt", + ).to(device) + generated_ids = self.text_encoder.generate(**model_inputs, max_new_tokens=512) + generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids) + ] + output_text = self.vl_processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + )[0] + return output_text.strip() + + def check_inputs( + self, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width, layers): + latents = latents.view(batch_size, layers, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4, 6) + latents = latents.reshape(batch_size, layers * (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, layers, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, layers + 1, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 1, 4, 2, 5, 3, 6) + + latents = latents.reshape(batch_size, layers + 1, channels // (2 * 2), height, width) + latents = latents.permute(0, 2, 1, 3, 4) # (b, c, f, h, w) + + return latents + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + image_latents = (image_latents - latents_mean) / latents_std + + return image_latents + + def prepare_latents( + self, + image, + batch_size, + num_channels_latents, + height, + width, + layers, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = ( + batch_size, + layers + 1, + num_channels_latents, + height, + width, + ) ### the generated first image is combined image + + image_latents = None + if image is not None: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latent_height, image_latent_width = image_latents.shape[3:] + image_latents = image_latents.permute(0, 2, 1, 3, 4) # (b, c, f, h, w) -> (b, f, c, h, w) + image_latents = self._pack_latents( + image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width, 1 + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width, layers + 1) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents, image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput | None = None, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + true_cfg_scale: float = 4.0, + layers: int | None = 4, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float | None = None, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + resolution: int = 640, + cfg_normalize: bool = False, + use_en_prompt: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + true_cfg_scale (`float`, *optional*, defaults to 1.0): + true_cfg_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free + Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of + equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is + enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale + encourages to generate images that are closely linked to the text `prompt`, usually at the expense of + lower image quality. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to None): + A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance + where the guidance scale is applied during inference through noise prediction rescaling, guidance + distilled models take the guidance scale directly as an input parameter during forward pass. Guidance + scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images + that are closely linked to the text `prompt`, usually at the expense of lower image quality. This + parameter in the pipeline is there to support future guidance-distilled models when they come up. It is + ignored when not using guidance distilled models. To enable traditional classifier-free guidance, + please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should + enable classifier-free guidance computations). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + resolution (`int`, *optional*, defaults to 640): + using different bucket in (640, 1024) to determin the condition and output resolution + cfg_normalize (`bool`, *optional*, defaults to `False`) + whether enable cfg normalization. + use_en_prompt (`bool`, *optional*, defaults to `False`) + automatic caption language if user does not provide caption + + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + image_size = image[0].size if isinstance(image, list) else image.size + assert resolution in [640, 1024], f"resolution must be either 640 or 1024, but got {resolution}" + calculated_width, calculated_height = calculate_dimensions( + resolution * resolution, image_size[0] / image_size[1] + ) + height = calculated_height + width = calculated_width + + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + # 2. Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + image = self.image_processor.resize(image, calculated_height, calculated_width) + prompt_image = image + image = self.image_processor.preprocess(image, calculated_height, calculated_width) + image = image.unsqueeze(2) + image = image.to(dtype=self.text_encoder.dtype) + + if prompt is None or prompt == "" or prompt == " ": + prompt = self.get_image_caption(prompt_image, use_en_prompt=use_en_prompt, device=device) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + + if true_cfg_scale > 1 and not has_neg_prompt: + logger.warning( + f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided." + ) + elif true_cfg_scale <= 1 and has_neg_prompt: + logger.warning( + " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1" + ) + + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, image_latents = self.prepare_latents( + image, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + layers, + prompt_embeds.dtype, + device, + generator, + latents, + ) + img_shapes = [ + [ + *[ + (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2) + for _ in range(layers + 1) + ], + (1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2), + ] + ] * batch_size + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + image_seq_len = latents.shape[1] + base_seqlen = 256 * 256 / 16 / 16 + mu = (image_latents.shape[1] / base_seqlen) ** 0.5 + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds and guidance_scale is None: + raise ValueError("guidance_scale is required for guidance-distilled model.") + elif self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + elif not self.transformer.config.guidance_embeds and guidance_scale is not None: + logger.warning( + f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled." + ) + guidance = None + elif not self.transformer.config.guidance_embeds and guidance_scale is None: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + is_rgb = torch.tensor([0] * batch_size).to(device=device, dtype=torch.long) + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + attention_kwargs=self.attention_kwargs, + additional_t_cond=is_rgb, + return_dict=False, + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + attention_kwargs=self.attention_kwargs, + additional_t_cond=is_rgb, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + if cfg_normalize: + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + else: + noise_pred = comb_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, layers, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + + b, c, f, h, w = latents.shape + + latents = latents[:, :, 1:] # remove the first frame as it is the orgin input + + latents = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w) + + image = self.vae.decode(latents, return_dict=False)[0] # (b f) c 1 h w + + image = image.squeeze(2) + + image = self.image_processor.postprocess(image, output_type=output_type) + images = [] + for bidx in range(b): + images.append(image[bidx * f : (bidx + 1) * f]) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (images,) + + return QwenImagePipelineOutput(images=images) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..91684f35f153516de83cd65ebed28164970b7548 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana/__init__.py @@ -0,0 +1,53 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_sana"] = ["SanaPipeline"] + _import_structure["pipeline_sana_controlnet"] = ["SanaControlNetPipeline"] + _import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"] + _import_structure["pipeline_sana_sprint_img2img"] = ["SanaSprintImg2ImgPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_sana import SanaPipeline + from .pipeline_sana_controlnet import SanaControlNetPipeline + from .pipeline_sana_sprint import SanaSprintPipeline + from .pipeline_sana_sprint_img2img import SanaSprintImg2ImgPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..b9d095906e06ba975233a2a1577d1dac651c813d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class SanaPipelineOutput(BaseOutput): + """ + Output class for Sana pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image] | np.ndarray diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana/pipeline_sana.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana/pipeline_sana.py new file mode 100644 index 0000000000000000000000000000000000000000..17e0be9ba7aa23a9c810e49cce88b3796a1c7ae1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana/pipeline_sana.py @@ -0,0 +1,1040 @@ +# Copyright 2025 SANA Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +import warnings +from typing import Any, Callable + +import torch +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PixArtImageProcessor +from ...loaders import SanaLoraLoaderMixin +from ...models import AutoencoderDC, SanaTransformer2DModel +from ...schedulers import DPMSolverMultistepScheduler +from ...utils import ( + BACKENDS_MAPPING, + USE_PEFT_BACKEND, + deprecate, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import get_device, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..pixart_alpha.pipeline_pixart_alpha import ( + ASPECT_RATIO_512_BIN, + ASPECT_RATIO_1024_BIN, +) +from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN +from .pipeline_output import SanaPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +ASPECT_RATIO_4096_BIN = { + "0.25": [2048.0, 8192.0], + "0.26": [2048.0, 7936.0], + "0.27": [2048.0, 7680.0], + "0.28": [2048.0, 7424.0], + "0.32": [2304.0, 7168.0], + "0.33": [2304.0, 6912.0], + "0.35": [2304.0, 6656.0], + "0.4": [2560.0, 6400.0], + "0.42": [2560.0, 6144.0], + "0.48": [2816.0, 5888.0], + "0.5": [2816.0, 5632.0], + "0.52": [2816.0, 5376.0], + "0.57": [3072.0, 5376.0], + "0.6": [3072.0, 5120.0], + "0.68": [3328.0, 4864.0], + "0.72": [3328.0, 4608.0], + "0.78": [3584.0, 4608.0], + "0.82": [3584.0, 4352.0], + "0.88": [3840.0, 4352.0], + "0.94": [3840.0, 4096.0], + "1.0": [4096.0, 4096.0], + "1.07": [4096.0, 3840.0], + "1.13": [4352.0, 3840.0], + "1.21": [4352.0, 3584.0], + "1.29": [4608.0, 3584.0], + "1.38": [4608.0, 3328.0], + "1.46": [4864.0, 3328.0], + "1.67": [5120.0, 3072.0], + "1.75": [5376.0, 3072.0], + "2.0": [5632.0, 2816.0], + "2.09": [5888.0, 2816.0], + "2.4": [6144.0, 2560.0], + "2.5": [6400.0, 2560.0], + "2.89": [6656.0, 2304.0], + "3.0": [6912.0, 2304.0], + "3.11": [7168.0, 2304.0], + "3.62": [7424.0, 2048.0], + "3.75": [7680.0, 2048.0], + "3.88": [7936.0, 2048.0], + "4.0": [8192.0, 2048.0], +} + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import SanaPipeline + + >>> pipe = SanaPipeline.from_pretrained( + ... "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers", torch_dtype=torch.float32 + ... ) + >>> pipe.to("cuda") + >>> pipe.text_encoder.to(torch.bfloat16) + >>> pipe.transformer = pipe.transformer.to(torch.bfloat16) + + >>> image = pipe(prompt='a cyberpunk cat with a neon sign that says "Sana"')[0] + >>> image[0].save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin): + r""" + Pipeline for text-to-image generation using [Sana](https://huggingface.co/papers/2410.10629). + """ + + # fmt: off + bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}") + # fmt: on + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + text_encoder: Gemma2PreTrainedModel, + vae: AutoencoderDC, + transformer: SanaTransformer2DModel, + scheduler: DPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.encoder_block_out_channels) - 1) + if hasattr(self, "vae") and self.vae is not None + else 32 + ) + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + def _get_gemma_prompt_embeds( + self, + prompt: str | list[str], + device: torch.device, + dtype: torch.dtype, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + + # prepare complex human instruction + if not complex_human_instruction: + max_length_all = max_sequence_length + else: + chi_prompt = "\n".join(complex_human_instruction) + prompt = [chi_prompt + p for p in prompt] + num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) + max_length_all = num_chi_prompt_tokens + max_sequence_length - 2 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length_all, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + lora_scale: float | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the "" string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + + if device is None: + device = self._execution_device + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + # See Section 3.1. of the paper. + max_length = max_sequence_length + select_index = [0] + list(range(-max_length + 1, 0)) + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + ) + + prompt_embeds = prompt_embeds[:, select_index] + prompt_attention_mask = prompt_attention_mask[:, select_index] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=False, + ) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + if self.text_encoder is not None: + if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: int | None = 1, + height: int = 1024, + width: int = 1024, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + clean_caption: bool = False, + use_resolution_binning: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 300, + complex_human_instruction: list[str] = [ + "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:", + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.", + "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.", + "Here are examples of how to transform or refine prompts:", + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.", + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.", + "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:", + "User Prompt: ", + ], + ) -> SanaPipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 20): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + attention_kwargs: + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `300`): + Maximum sequence length to use with the `prompt`. + complex_human_instruction (`list[str]`, *optional*): + Instructions for complex human attention: + https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55. + + Examples: + + Returns: + [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + if use_resolution_binning: + if self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_4096_BIN + elif self.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_2048_BIN + elif self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.transformer.config.sample_size == 16: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + callback_on_step_end_tensor_inputs, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + lora_scale=lora_scale, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + transformer_dtype = self.transformer.dtype + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + timestep = timestep * self.transformer.config.timestep_scale + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input.to(dtype=transformer_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype), + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + return_dict=False, + attention_kwargs=self.attention_kwargs, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + else: + latents = latents.to(self.vae.dtype) + torch_accelerator_module = getattr(torch, get_device(), torch.cuda) + oom_error = ( + torch.OutOfMemoryError + if is_torch_version(">=", "2.5.0") + else torch_accelerator_module.OutOfMemoryError + ) + try: + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + except oom_error as e: + warnings.warn( + f"{e}. \n" + f"Try to use VAE tiling for large images. For example: \n" + f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)" + ) + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return SanaPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana/pipeline_sana_controlnet.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana/pipeline_sana_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..d976c7035d9d136b28b861a2803c7d723c82659b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana/pipeline_sana_controlnet.py @@ -0,0 +1,1135 @@ +# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +import warnings +from typing import Any, Callable + +import torch +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, PixArtImageProcessor +from ...loaders import SanaLoraLoaderMixin +from ...models import AutoencoderDC, SanaControlNetModel, SanaTransformer2DModel +from ...schedulers import DPMSolverMultistepScheduler +from ...utils import ( + BACKENDS_MAPPING, + USE_PEFT_BACKEND, + deprecate, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import get_device, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..pixart_alpha.pipeline_pixart_alpha import ( + ASPECT_RATIO_512_BIN, + ASPECT_RATIO_1024_BIN, +) +from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN +from .pipeline_output import SanaPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +ASPECT_RATIO_4096_BIN = { + "0.25": [2048.0, 8192.0], + "0.26": [2048.0, 7936.0], + "0.27": [2048.0, 7680.0], + "0.28": [2048.0, 7424.0], + "0.32": [2304.0, 7168.0], + "0.33": [2304.0, 6912.0], + "0.35": [2304.0, 6656.0], + "0.4": [2560.0, 6400.0], + "0.42": [2560.0, 6144.0], + "0.48": [2816.0, 5888.0], + "0.5": [2816.0, 5632.0], + "0.52": [2816.0, 5376.0], + "0.57": [3072.0, 5376.0], + "0.6": [3072.0, 5120.0], + "0.68": [3328.0, 4864.0], + "0.72": [3328.0, 4608.0], + "0.78": [3584.0, 4608.0], + "0.82": [3584.0, 4352.0], + "0.88": [3840.0, 4352.0], + "0.94": [3840.0, 4096.0], + "1.0": [4096.0, 4096.0], + "1.07": [4096.0, 3840.0], + "1.13": [4352.0, 3840.0], + "1.21": [4352.0, 3584.0], + "1.29": [4608.0, 3584.0], + "1.38": [4608.0, 3328.0], + "1.46": [4864.0, 3328.0], + "1.67": [5120.0, 3072.0], + "1.75": [5376.0, 3072.0], + "2.0": [5632.0, 2816.0], + "2.09": [5888.0, 2816.0], + "2.4": [6144.0, 2560.0], + "2.5": [6400.0, 2560.0], + "2.89": [6656.0, 2304.0], + "3.0": [6912.0, 2304.0], + "3.11": [7168.0, 2304.0], + "3.62": [7424.0, 2048.0], + "3.75": [7680.0, 2048.0], + "3.88": [7936.0, 2048.0], + "4.0": [8192.0, 2048.0], +} + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import SanaControlNetPipeline + >>> from diffusers.utils import load_image + + >>> pipe = SanaControlNetPipeline.from_pretrained( + ... "ishan24/Sana_600M_1024px_ControlNetPlus_diffusers", + ... variant="fp16", + ... torch_dtype={"default": torch.bfloat16, "controlnet": torch.float16, "transformer": torch.float16}, + ... device_map="balanced", + ... ) + >>> cond_image = load_image( + ... "https://huggingface.co/ishan24/Sana_600M_1024px_ControlNet_diffusers/resolve/main/hed_example.png" + ... ) + >>> prompt = 'a cat with a neon sign that says "Sana"' + >>> image = pipe( + ... prompt, + ... control_image=cond_image, + ... ).images[0] + >>> image.save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin): + r""" + Pipeline for text-to-image generation using [Sana](https://huggingface.co/papers/2410.10629). + """ + + # fmt: off + bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}") + # fmt: on + + model_cpu_offload_seq = "text_encoder->controlnet->transformer->vae" + _callback_tensor_inputs = ["latents", "control_image", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + text_encoder: Gemma2PreTrainedModel, + vae: AutoencoderDC, + transformer: SanaTransformer2DModel, + controlnet: SanaControlNetModel, + scheduler: DPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + controlnet=controlnet, + scheduler=scheduler, + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.encoder_block_out_channels) - 1) + if hasattr(self, "vae") and self.vae is not None + else 32 + ) + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds + def _get_gemma_prompt_embeds( + self, + prompt: str | list[str], + device: torch.device, + dtype: torch.dtype, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + + # prepare complex human instruction + if not complex_human_instruction: + max_length_all = max_sequence_length + else: + chi_prompt = "\n".join(complex_human_instruction) + prompt = [chi_prompt + p for p in prompt] + num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) + max_length_all = num_chi_prompt_tokens + max_sequence_length - 2 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length_all, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + lora_scale: float | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the "" string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + + if device is None: + device = self._execution_device + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + # See Section 3.1. of the paper. + max_length = max_sequence_length + select_index = [0] + list(range(-max_length + 1, 0)) + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + ) + + prompt_embeds = prompt_embeds[:, select_index] + prompt_attention_mask = prompt_attention_mask[:, select_index] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=False, + ) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + if self.text_encoder is not None: + if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 4.5, + control_image: PipelineImageInput = None, + controlnet_conditioning_scale: float | list[float] = 1.0, + num_images_per_prompt: int | None = 1, + height: int = 1024, + width: int = 1024, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + clean_caption: bool = False, + use_resolution_binning: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 300, + complex_human_instruction: list[str] = [ + "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:", + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.", + "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.", + "Here are examples of how to transform or refine prompts:", + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.", + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.", + "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:", + "User Prompt: ", + ], + ) -> SanaPipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 20): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,: + `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + attention_kwargs: + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `300`): + Maximum sequence length to use with the `prompt`. + complex_human_instruction (`list[str]`, *optional*): + Instructions for complex human attention: + https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55. + + Examples: + + Returns: + [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + if use_resolution_binning: + if self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_4096_BIN + elif self.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_2048_BIN + elif self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.transformer.config.sample_size == 16: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + callback_on_step_end_tensor_inputs, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + lora_scale=lora_scale, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare control image + if isinstance(self.controlnet, SanaControlNetModel): + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + height, width = control_image.shape[-2:] + + control_image = self.vae.encode(control_image).latent + control_image = control_image * self.vae.config.scaling_factor + else: + raise ValueError("`controlnet` must be of type `SanaControlNetModel`.") + + # 5. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + + # 6. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + controlnet_dtype = self.controlnet.dtype + transformer_dtype = self.transformer.dtype + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # controlnet(s) inference + controlnet_block_samples = self.controlnet( + latent_model_input.to(dtype=controlnet_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=controlnet_dtype), + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + return_dict=False, + attention_kwargs=self.attention_kwargs, + controlnet_cond=control_image, + conditioning_scale=controlnet_conditioning_scale, + )[0] + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input.to(dtype=transformer_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype), + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + return_dict=False, + attention_kwargs=self.attention_kwargs, + controlnet_block_samples=tuple(t.to(dtype=transformer_dtype) for t in controlnet_block_samples), + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + else: + latents = latents.to(self.vae.dtype) + torch_accelerator_module = getattr(torch, get_device(), torch.cuda) + oom_error = ( + torch.OutOfMemoryError + if is_torch_version(">=", "2.5.0") + else torch_accelerator_module.OutOfMemoryError + ) + try: + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + except oom_error as e: + warnings.warn( + f"{e}. \n" + f"Try to use VAE tiling for large images. For example: \n" + f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)" + ) + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return SanaPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana/pipeline_sana_sprint.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana/pipeline_sana_sprint.py new file mode 100644 index 0000000000000000000000000000000000000000..c85f05275f51165e154306614b34cc601e9988a5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana/pipeline_sana_sprint.py @@ -0,0 +1,922 @@ +# Copyright 2025 SANA-Sprint Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +import warnings +from typing import Any, Callable + +import torch +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PixArtImageProcessor +from ...loaders import SanaLoraLoaderMixin +from ...models import AutoencoderDC, SanaTransformer2DModel +from ...schedulers import DPMSolverMultistepScheduler +from ...utils import ( + BACKENDS_MAPPING, + USE_PEFT_BACKEND, + deprecate, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import get_device, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN +from .pipeline_output import SanaPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import SanaSprintPipeline + + >>> pipe = SanaSprintPipeline.from_pretrained( + ... "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> image = pipe(prompt="a tiny astronaut hatching from an egg on the moon")[0] + >>> image[0].save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin): + r""" + Pipeline for text-to-image generation using [SANA-Sprint](https://huggingface.co/papers/2503.09641). + """ + + # fmt: off + bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}") + # fmt: on + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + text_encoder: Gemma2PreTrainedModel, + vae: AutoencoderDC, + transformer: SanaTransformer2DModel, + scheduler: DPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.encoder_block_out_channels) - 1) + if hasattr(self, "vae") and self.vae is not None + else 32 + ) + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds + def _get_gemma_prompt_embeds( + self, + prompt: str | list[str], + device: torch.device, + dtype: torch.dtype, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + + # prepare complex human instruction + if not complex_human_instruction: + max_length_all = max_sequence_length + else: + chi_prompt = "\n".join(complex_human_instruction) + prompt = [chi_prompt + p for p in prompt] + num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) + max_length_all = num_chi_prompt_tokens + max_sequence_length - 2 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length_all, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + def encode_prompt( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + lora_scale: float | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + + if device is None: + device = self._execution_device + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + # See Section 3.1. of the paper. + max_length = max_sequence_length + select_index = [0] + list(range(-max_length + 1, 0)) + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + ) + + prompt_embeds = prompt_embeds[:, select_index] + prompt_attention_mask = prompt_attention_mask[:, select_index] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if self.text_encoder is not None: + if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + num_inference_steps, + timesteps, + max_timesteps, + intermediate_timesteps, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if timesteps is not None and len(timesteps) != num_inference_steps + 1: + raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.") + + if timesteps is not None and max_timesteps is not None: + raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.") + + if timesteps is None and max_timesteps is None: + raise ValueError("Should provide either `timesteps` or `max_timesteps`.") + + if intermediate_timesteps is not None and num_inference_steps != 2: + raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.") + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + num_inference_steps: int = 2, + timesteps: list[int] = None, + max_timesteps: float = 1.57080, + intermediate_timesteps: float = 1.3, + guidance_scale: float = 4.5, + num_images_per_prompt: int | None = 1, + height: int = 1024, + width: int = 1024, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + clean_caption: bool = False, + use_resolution_binning: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 300, + complex_human_instruction: list[str] = [ + "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:", + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.", + "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.", + "Here are examples of how to transform or refine prompts:", + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.", + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.", + "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:", + "User Prompt: ", + ], + ) -> SanaPipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + num_inference_steps (`int`, *optional*, defaults to 20): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + max_timesteps (`float`, *optional*, defaults to 1.57080): + The maximum timestep value used in the SCM scheduler. + intermediate_timesteps (`float`, *optional*, defaults to 1.3): + The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2). + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 4.5): + Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + attention_kwargs: + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `300`): + Maximum sequence length to use with the `prompt`. + complex_human_instruction (`list[str]`, *optional*): + Instructions for complex human attention: + https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55. + + Examples: + + Returns: + [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + if use_resolution_binning: + if self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt=prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + timesteps=timesteps, + max_timesteps=max_timesteps, + intermediate_timesteps=intermediate_timesteps, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + ) = self.encode_prompt( + prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + lora_scale=lora_scale, + ) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + timestep_device, + timesteps, + sigmas=None, + max_timesteps=max_timesteps, + intermediate_timesteps=intermediate_timesteps, + ) + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(0) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + latents = latents * self.scheduler.config.sigma_data + + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype) + guidance = guidance * self.transformer.config.guidance_embeds_scale + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + timesteps = timesteps[:-1] + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + transformer_dtype = self.transformer.dtype + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + latents_model_input = latents / self.scheduler.config.sigma_data + + scm_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep)) + + scm_timestep_expanded = scm_timestep.view(-1, 1, 1, 1) + latent_model_input = latents_model_input * torch.sqrt( + scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2 + ) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input.to(dtype=transformer_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype), + encoder_attention_mask=prompt_attention_mask, + guidance=guidance, + timestep=scm_timestep, + return_dict=False, + attention_kwargs=self.attention_kwargs, + )[0] + + noise_pred = ( + (1 - 2 * scm_timestep_expanded) * latent_model_input + + (1 - 2 * scm_timestep_expanded + 2 * scm_timestep_expanded**2) * noise_pred + ) / torch.sqrt(scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2) + noise_pred = noise_pred.float() * self.scheduler.config.sigma_data + + # compute previous image: x_t -> x_t-1 + latents, denoised = self.scheduler.step( + noise_pred, timestep, latents, **extra_step_kwargs, return_dict=False + ) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + latents = denoised / self.scheduler.config.sigma_data + if output_type == "latent": + image = latents + else: + latents = latents.to(self.vae.dtype) + torch_accelerator_module = getattr(torch, get_device(), torch.cuda) + oom_error = ( + torch.OutOfMemoryError + if is_torch_version(">=", "2.5.0") + else torch_accelerator_module.OutOfMemoryError + ) + try: + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + except oom_error as e: + warnings.warn( + f"{e}. \n" + f"Try to use VAE tiling for large images. For example: \n" + f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)" + ) + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return SanaPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..a17c494e88eb1445a62970c8907c3ae9a0de1d07 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py @@ -0,0 +1,1006 @@ +# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +import warnings +from typing import Any, Callable + +import torch +import torch.nn.functional as F +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, PixArtImageProcessor +from ...loaders import SanaLoraLoaderMixin +from ...models import AutoencoderDC, SanaTransformer2DModel +from ...schedulers import DPMSolverMultistepScheduler +from ...utils import ( + BACKENDS_MAPPING, + USE_PEFT_BACKEND, + deprecate, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import get_device, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN +from .pipeline_output import SanaPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import SanaSprintImg2ImgPipeline + >>> from diffusers.utils.loading_utils import load_image + + >>> pipe = SanaSprintImg2ImgPipeline.from_pretrained( + ... "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png" + ... ) + + + >>> image = pipe(prompt="a cute pink bear", image=image, strength=0.5, height=832, width=480).images[0] + >>> image[0].save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class SanaSprintImg2ImgPipeline(DiffusionPipeline, SanaLoraLoaderMixin): + r""" + Pipeline for text-to-image generation using [SANA-Sprint](https://huggingface.co/papers/2503.09641). + """ + + # fmt: off + bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}") + # fmt: on + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + text_encoder: Gemma2PreTrainedModel, + vae: AutoencoderDC, + transformer: SanaTransformer2DModel, + scheduler: DPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.encoder_block_out_channels) - 1) + if hasattr(self, "vae") and self.vae is not None + else 32 + ) + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds + def _get_gemma_prompt_embeds( + self, + prompt: str | list[str], + device: torch.device, + dtype: torch.dtype, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + + # prepare complex human instruction + if not complex_human_instruction: + max_length_all = max_sequence_length + else: + chi_prompt = "\n".join(complex_human_instruction) + prompt = [chi_prompt + p for p in prompt] + num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) + max_length_all = num_chi_prompt_tokens + max_sequence_length - 2 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length_all, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.sana.pipeline_sana_sprint.SanaSprintPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + lora_scale: float | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + + if device is None: + device = self._execution_device + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + # See Section 3.1. of the paper. + max_length = max_sequence_length + select_index = [0] + list(range(-max_length + 1, 0)) + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + ) + + prompt_embeds = prompt_embeds[:, select_index] + prompt_attention_mask = prompt_attention_mask[:, select_index] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + if self.text_encoder is not None: + if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + strength, + height, + width, + num_inference_steps, + timesteps, + max_timesteps, + intermediate_timesteps, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + prompt_attention_mask=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if timesteps is not None and len(timesteps) != num_inference_steps + 1: + raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.") + + if timesteps is not None and max_timesteps is not None: + raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.") + + if timesteps is None and max_timesteps is None: + raise ValueError("Should provide either `timesteps` or `max_timesteps`.") + + if intermediate_timesteps is not None and num_inference_steps != 2: + raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.") + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def prepare_image( + self, + image: PipelineImageInput, + width: int, + height: int, + device: torch.device, + dtype: torch.dtype, + ): + if isinstance(image, torch.Tensor): + if image.ndim == 3: + image = image.unsqueeze(0) + # Resize if current dimensions do not match target dimensions. + if image.shape[2] != height or image.shape[3] != width: + image = F.interpolate(image, size=(height, width), mode="bilinear", align_corners=False) + + image = self.image_processor.preprocess(image, height=height, width=width) + + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image = image.to(device=device, dtype=dtype) + + return image + + def prepare_latents( + self, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + + if image.shape[1] != num_channels_latents: + image = self.vae.encode(image).latent + image_latents = image * self.vae.config.scaling_factor * self.scheduler.config.sigma_data + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # adapt from https://github.com/huggingface/diffusers/blob/c36f8487df35895421c15f351c7d360bd680[…]/examples/research_projects/sana/train_sana_sprint_diffusers.py + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) * self.scheduler.config.sigma_data + latents = torch.cos(timestep) * image_latents + torch.sin(timestep) * noise + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + num_inference_steps: int = 2, + timesteps: list[int] = None, + max_timesteps: float = 1.57080, + intermediate_timesteps: float = 1.3, + guidance_scale: float = 4.5, + image: PipelineImageInput = None, + strength: float = 0.6, + num_images_per_prompt: int | None = 1, + height: int = 1024, + width: int = 1024, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + clean_caption: bool = False, + use_resolution_binning: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 300, + complex_human_instruction: list[str] = [ + "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:", + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.", + "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.", + "Here are examples of how to transform or refine prompts:", + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.", + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.", + "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:", + "User Prompt: ", + ], + ) -> SanaPipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + num_inference_steps (`int`, *optional*, defaults to 20): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + max_timesteps (`float`, *optional*, defaults to 1.57080): + The maximum timestep value used in the SCM scheduler. + intermediate_timesteps (`float`, *optional*, defaults to 1.3): + The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2). + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + attention_kwargs: + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `300`): + Maximum sequence length to use with the `prompt`. + complex_human_instruction (`list[str]`, *optional*): + Instructions for complex human attention: + https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55. + + Examples: + + Returns: + [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + if use_resolution_binning: + if self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt=prompt, + strength=strength, + height=height, + width=width, + num_inference_steps=num_inference_steps, + timesteps=timesteps, + max_timesteps=max_timesteps, + intermediate_timesteps=intermediate_timesteps, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None + + # 2. Preprocess image + init_image = self.prepare_image(image, width, height, device, self.vae.dtype) + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + ) = self.encode_prompt( + prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + lora_scale=lora_scale, + ) + + # 5. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=None, + max_timesteps=max_timesteps, + intermediate_timesteps=intermediate_timesteps, + ) + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(0) + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1] + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # I think this is redundant given the scaling in prepare_latents + # latents = latents * self.scheduler.config.sigma_data + + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype) + guidance = guidance * self.transformer.config.guidance_embeds_scale + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + timesteps = timesteps[:-1] + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + transformer_dtype = self.transformer.dtype + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + latents_model_input = latents / self.scheduler.config.sigma_data + + scm_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep)) + + scm_timestep_expanded = scm_timestep.view(-1, 1, 1, 1) + latent_model_input = latents_model_input * torch.sqrt( + scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2 + ) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input.to(dtype=transformer_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype), + encoder_attention_mask=prompt_attention_mask, + guidance=guidance, + timestep=scm_timestep, + return_dict=False, + attention_kwargs=self.attention_kwargs, + )[0] + + noise_pred = ( + (1 - 2 * scm_timestep_expanded) * latent_model_input + + (1 - 2 * scm_timestep_expanded + 2 * scm_timestep_expanded**2) * noise_pred + ) / torch.sqrt(scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2) + noise_pred = noise_pred.float() * self.scheduler.config.sigma_data + + # compute previous image: x_t -> x_t-1 + latents, denoised = self.scheduler.step( + noise_pred, timestep, latents, **extra_step_kwargs, return_dict=False + ) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + latents = denoised / self.scheduler.config.sigma_data + if output_type == "latent": + image = latents + else: + latents = latents.to(self.vae.dtype) + torch_accelerator_module = getattr(torch, get_device(), torch.cuda) + oom_error = ( + torch.OutOfMemoryError + if is_torch_version(">=", "2.5.0") + else torch_accelerator_module.OutOfMemoryError + ) + try: + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + except oom_error as e: + warnings.warn( + f"{e}. \n" + f"Try to use VAE tiling for large images. For example: \n" + f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)" + ) + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return SanaPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana_video/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana_video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..73e224bf749d358e46a2cd4e762c7875b8eeb761 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana_video/__init__.py @@ -0,0 +1,49 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_sana_video"] = ["SanaVideoPipeline"] + _import_structure["pipeline_sana_video_i2v"] = ["SanaImageToVideoPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_sana_video import SanaVideoPipeline + from .pipeline_sana_video_i2v import SanaImageToVideoPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana_video/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana_video/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..a625ad3e1d3449b48308bb7ce4a51fcfcc36b23b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana_video/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from ...utils import BaseOutput + + +@dataclass +class SanaVideoPipelineOutput(BaseOutput): + r""" + Output class for Sana-Video pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana_video/pipeline_sana_video.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana_video/pipeline_sana_video.py new file mode 100644 index 0000000000000000000000000000000000000000..8b44dfc1143c4dba1e44b2c41d5ed1835c6cbf17 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana_video/pipeline_sana_video.py @@ -0,0 +1,1017 @@ +# Copyright 2025 SANA-Video Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +import warnings +from typing import Any, Callable + +import torch +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import SanaLoraLoaderMixin +from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel +from ...schedulers import DPMSolverMultistepScheduler +from ...utils import ( + BACKENDS_MAPPING, + USE_PEFT_BACKEND, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import get_device, is_torch_version, randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import SanaVideoPipelineOutput + + +ASPECT_RATIO_480_BIN = { + "0.5": [448.0, 896.0], + "0.57": [480.0, 832.0], + "0.68": [528.0, 768.0], + "0.78": [560.0, 720.0], + "1.0": [624.0, 624.0], + "1.13": [672.0, 592.0], + "1.29": [720.0, 560.0], + "1.46": [768.0, 528.0], + "1.67": [816.0, 496.0], + "1.75": [832.0, 480.0], + "2.0": [896.0, 448.0], +} + + +ASPECT_RATIO_720_BIN = { + "0.5": [672.0, 1344.0], + "0.57": [704.0, 1280.0], + "0.68": [800.0, 1152.0], + "0.78": [832.0, 1088.0], + "1.0": [960.0, 960.0], + "1.13": [1024.0, 896.0], + "1.29": [1088.0, 832.0], + "1.46": [1152.0, 800.0], + "1.67": [1248.0, 736.0], + "1.75": [1280.0, 704.0], + "2.0": [1344.0, 672.0], +} + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import SanaVideoPipeline + >>> from diffusers.utils import export_to_video + + >>> pipe = SanaVideoPipeline.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers") + >>> pipe.transformer.to(torch.bfloat16) + >>> pipe.text_encoder.to(torch.bfloat16) + >>> pipe.vae.to(torch.float32) + >>> pipe.to("cuda") + >>> motion_score = 30 + + >>> prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. The camera focuses on his lifted gaze, clear and emotional." + >>> negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience." + >>> motion_prompt = f" motion score: {motion_score}." + >>> prompt = prompt + motion_prompt + + >>> output = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=480, + ... width=832, + ... frames=81, + ... guidance_scale=6, + ... num_inference_steps=50, + ... generator=torch.Generator(device="cuda").manual_seed(42), + ... ).frames[0] + + >>> export_to_video(output, "sana-video-output.mp4", fps=16) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using [Sana](https://huggingface.co/papers/2509.24695). This model inherits + from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all + pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`GemmaTokenizer`] or [`GemmaTokenizerFast`]): + The tokenizer used to tokenize the prompt. + text_encoder ([`Gemma2PreTrainedModel`]): + Text encoder model to encode the input prompts. + vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + transformer ([`SanaVideoTransformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`DPMSolverMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + # fmt: off + bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}") + # fmt: on + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + text_encoder: Gemma2PreTrainedModel, + vae: AutoencoderDC | AutoencoderKLWan, + transformer: SanaVideoTransformer3DModel, + scheduler: DPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + + self.vae_scale_factor = self.vae_scale_factor_spatial + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds + def _get_gemma_prompt_embeds( + self, + prompt: str | list[str], + device: torch.device, + dtype: torch.dtype, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + + # prepare complex human instruction + if not complex_human_instruction: + max_length_all = max_sequence_length + else: + chi_prompt = "\n".join(complex_human_instruction) + prompt = [chi_prompt + p for p in prompt] + num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) + max_length_all = num_chi_prompt_tokens + max_sequence_length - 2 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length_all, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_videos_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + lora_scale: float | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_videos_per_prompt (`int`, *optional*, defaults to 1): + number of videos that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the "" string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + + if device is None: + device = self._execution_device + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + # See Section 3.1. of the paper. + max_length = max_sequence_length + select_index = [0] + list(range(-max_length + 1, 0)) + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + ) + + prompt_embeds = prompt_embeds[:, select_index] + prompt_attention_mask = prompt_attention_mask[:, select_index] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=False, + ) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_videos_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + if self.text_encoder is not None: + if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str = "", + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 6.0, + num_videos_per_prompt: int | None = 1, + height: int = 480, + width: int = 832, + frames: int = 81, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + clean_caption: bool = False, + use_resolution_binning: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 300, + complex_human_instruction: list[str] = [ + "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for video generation. Evaluate the level of detail in the user prompt:", + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, motion, and temporal relationships to create vivid and dynamic scenes.", + "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.", + "Here are examples of how to transform or refine prompts:", + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat slowly settling into a curled position, peacefully falling asleep on a warm sunny windowsill, with gentle sunlight filtering through surrounding pots of blooming red flowers.", + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps gradually lighting up, a diverse crowd of people in colorful clothing walking past, and a double-decker bus smoothly passing by towering glass skyscrapers.", + "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:", + "User Prompt: ", + ], + ) -> SanaVideoPipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to + the text `prompt`, usually at the expense of lower video quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + height (`int`, *optional*, defaults to 480): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to 832): + The width in pixels of the generated video. + frames (`int`, *optional*, defaults to 81): + The number of frames in the generated video. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between mp4 or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`SanaVideoPipelineOutput`] instead of a plain tuple. + attention_kwargs: + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_480_BIN` or `ASPECT_RATIO_720_BIN`. After the produced latents are decoded into videos, + they are resized back to the requested resolution. Useful for generating non-square videos. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `300`): + Maximum sequence length to use with the `prompt`. + complex_human_instruction (`list[str]`, *optional*): + Instructions for complex human attention: + https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55. + + Examples: + + Returns: + [`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated videos + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + if use_resolution_binning: + if self.transformer.config.sample_size == 30: + aspect_ratio_bin = ASPECT_RATIO_480_BIN + elif self.transformer.config.sample_size == 22: + aspect_ratio_bin = ASPECT_RATIO_720_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.video_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + callback_on_step_end_tensor_inputs, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + lora_scale=lora_scale, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + height, + width, + frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + transformer_dtype = self.transformer.dtype + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input.to(dtype=transformer_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype), + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + return_dict=False, + attention_kwargs=self.attention_kwargs, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + video = latents + else: + latents = latents.to(self.vae.dtype) + torch_accelerator_module = getattr(torch, get_device(), torch.cuda) + oom_error = ( + torch.OutOfMemoryError + if is_torch_version(">=", "2.5.0") + else torch_accelerator_module.OutOfMemoryError + ) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + try: + video = self.vae.decode(latents, return_dict=False)[0] + except oom_error as e: + warnings.warn( + f"{e}. \n" + f"Try to use VAE tiling for large images. For example: \n" + f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)" + ) + + if use_resolution_binning: + video = self.video_processor.resize_and_crop_tensor(video, orig_width, orig_height) + + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return SanaVideoPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py new file mode 100644 index 0000000000000000000000000000000000000000..b90d7c6f5a60ae7bdeadd067a9c462fe1f7e3d50 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py @@ -0,0 +1,1066 @@ +# Copyright 2025 SANA-Video Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +import warnings +from typing import Any, Callable + +import PIL +import torch +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import SanaLoraLoaderMixin +from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + BACKENDS_MAPPING, + USE_PEFT_BACKEND, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import get_device, is_torch_version, randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import SanaVideoPipelineOutput +from .pipeline_sana_video import ASPECT_RATIO_480_BIN, ASPECT_RATIO_720_BIN + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import SanaImageToVideoPipeline + >>> from diffusers.utils import export_to_video, load_image + + >>> pipe = SanaImageToVideoPipeline.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers") + >>> pipe.transformer.to(torch.bfloat16) + >>> pipe.text_encoder.to(torch.bfloat16) + >>> pipe.vae.to(torch.float32) + >>> pipe.to("cuda") + >>> motion_score = 30 + + >>> prompt = "A woman stands against a stunning sunset backdrop, her long, wavy brown hair gently blowing in the breeze. She wears a sleeveless, light-colored blouse with a deep V-neckline, which accentuates her graceful posture. The warm hues of the setting sun cast a golden glow across her face and hair, creating a serene and ethereal atmosphere. The background features a blurred landscape with soft, rolling hills and scattered clouds, adding depth to the scene. The camera remains steady, capturing the tranquil moment from a medium close-up angle." + >>> negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience." + >>> motion_prompt = f" motion score: {motion_score}." + >>> prompt = prompt + motion_prompt + >>> image = load_image("https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/samples/i2v-1.png") + + >>> output = pipe( + ... image=image, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=480, + ... width=832, + ... frames=81, + ... guidance_scale=6, + ... num_inference_steps=50, + ... generator=torch.Generator(device="cuda").manual_seed(42), + ... ).frames[0] + + >>> export_to_video(output, "sana-ti2v-output.mp4", fps=16) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin): + r""" + Pipeline for image/text-to-video generation using [Sana](https://huggingface.co/papers/2509.24695). This model + inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all + pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`GemmaTokenizer`] or [`GemmaTokenizerFast`]): + The tokenizer used to tokenize the prompt. + text_encoder ([`Gemma2PreTrainedModel`]): + Text encoder model to encode the input prompts. + vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + transformer ([`SanaVideoTransformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + # fmt: off + bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}") + # fmt: on + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + text_encoder: Gemma2PreTrainedModel, + vae: AutoencoderDC | AutoencoderKLWan, + transformer: SanaVideoTransformer3DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + + self.vae_scale_factor = self.vae_scale_factor_spatial + + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size[1] if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size[0] if getattr(self, "transformer") is not None else 1 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds + def _get_gemma_prompt_embeds( + self, + prompt: str | list[str], + device: torch.device, + dtype: torch.dtype, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + + # prepare complex human instruction + if not complex_human_instruction: + max_length_all = max_sequence_length + else: + chi_prompt = "\n".join(complex_human_instruction) + prompt = [chi_prompt + p for p in prompt] + num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) + max_length_all = num_chi_prompt_tokens + max_sequence_length - 2 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length_all, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.sana_video.pipeline_sana_video.SanaVideoPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_videos_per_prompt: int = 1, + device: torch.device | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: list[str] | None = None, + lora_scale: float | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_videos_per_prompt (`int`, *optional*, defaults to 1): + number of videos that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the "" string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + + if device is None: + device = self._execution_device + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + # See Section 3.1. of the paper. + max_length = max_sequence_length + select_index = [0] + list(range(-max_length + 1, 0)) + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + ) + + prompt_embeds = prompt_embeds[:, select_index] + prompt_attention_mask = prompt_attention_mask[:, select_index] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=False, + ) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_videos_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + if self.text_encoder is not None: + if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + height, + width, + callback_on_step_end_tensor_inputs=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def prepare_latents( + self, + image: PipelineImageInput, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + image = image.unsqueeze(2) # [B, C, 1, H, W] + image = image.to(device=device, dtype=self.vae.dtype) + + if isinstance(generator, list): + image_latents = [retrieve_latents(self.vae.encode(image), sample_mode="argmax") for _ in generator] + image_latents = torch.cat(image_latents) + else: + image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax") + image_latents = image_latents.repeat(batch_size, 1, 1, 1, 1) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, -1, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to( + image_latents.device, image_latents.dtype + ) + image_latents = (image_latents - latents_mean) * latents_std + + latents[:, :, 0:1] = image_latents.to(dtype) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: str | list[str] = None, + negative_prompt: str = "", + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 6.0, + num_videos_per_prompt: int | None = 1, + height: int = 480, + width: int = 832, + frames: int = 81, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + clean_caption: bool = False, + use_resolution_binning: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 300, + complex_human_instruction: list[str] = [ + "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for video generation. Evaluate the level of detail in the user prompt:", + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, motion, and temporal relationships to create vivid and dynamic scenes.", + "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.", + "Here are examples of how to transform or refine prompts:", + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat slowly settling into a curled position, peacefully falling asleep on a warm sunny windowsill, with gentle sunlight filtering through surrounding pots of blooming red flowers.", + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps gradually lighting up, a diverse crowd of people in colorful clothing walking past, and a double-decker bus smoothly passing by towering glass skyscrapers.", + "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:", + "User Prompt: ", + ], + ) -> SanaVideoPipelineOutput | tuple: + """ + Function invoked when calling the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the video generation on. The first frame of the generated video will be + conditioned on this image. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to + the text `prompt`, usually at the expense of lower video quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + height (`int`, *optional*, defaults to 480): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to 832): + The width in pixels of the generated video. + frames (`int`, *optional*, defaults to 81): + The number of frames in the generated video. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between mp4 or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`SanaVideoPipelineOutput`] instead of a plain tuple. + attention_kwargs: + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_480_BIN` or `ASPECT_RATIO_720_BIN`. After the produced latents are decoded into videos, + they are resized back to the requested resolution. Useful for generating non-square videos. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `300`): + Maximum sequence length to use with the `prompt`. + complex_human_instruction (`list[str]`, *optional*): + Instructions for complex human attention: + https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55. + + Examples: + + Returns: + [`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated videos + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + if use_resolution_binning: + if self.transformer.config.sample_size == 30: + aspect_ratio_bin = ASPECT_RATIO_480_BIN + elif self.transformer.config.sample_size == 22: + aspect_ratio_bin = ASPECT_RATIO_720_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.video_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + image, + height, + width, + callback_on_step_end_tensor_inputs, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + lora_scale=lora_scale, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + + latents = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + latent_channels, + height, + width, + frames, + torch.float32, + device, + generator, + latents, + ) + + conditioning_mask = latents.new_zeros( + batch_size, + 1, + latents.shape[2] // self.transformer_temporal_patch_size, + latents.shape[3] // self.transformer_spatial_patch_size, + latents.shape[4] // self.transformer_spatial_patch_size, + ) + conditioning_mask[:, :, 0] = 1.0 + if self.do_classifier_free_guidance: + conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + transformer_dtype = self.transformer.dtype + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(conditioning_mask.shape) + timestep = timestep * (1 - conditioning_mask) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input.to(dtype=transformer_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype), + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + return_dict=False, + attention_kwargs=self.attention_kwargs, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + timestep, _ = timestep.chunk(2) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + + noise_pred = noise_pred[:, :, 1:] + noise_latents = latents[:, :, 1:] + pred_latents = self.scheduler.step( + noise_pred, t, noise_latents, **extra_step_kwargs, return_dict=False + )[0] + + latents = torch.cat([latents[:, :, :1], pred_latents], dim=2) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + video = latents + else: + latents = latents.to(self.vae.dtype) + torch_accelerator_module = getattr(torch, get_device(), torch.cuda) + oom_error = ( + torch.OutOfMemoryError + if is_torch_version(">=", "2.5.0") + else torch_accelerator_module.OutOfMemoryError + ) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + try: + video = self.vae.decode(latents, return_dict=False)[0] + except oom_error as e: + warnings.warn( + f"{e}. \n" + f"Try to use VAE tiling for large images. For example: \n" + f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)" + ) + + if use_resolution_binning: + video = self.video_processor.resize_and_crop_tensor(video, orig_width, orig_height) + + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return SanaVideoPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/semantic_stable_diffusion/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/semantic_stable_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..70f5b1a547c4b90e28109843ae3be2fca2e98c88 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/semantic_stable_diffusion/__init__.py @@ -0,0 +1,49 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_output"] = ["SemanticStableDiffusionPipelineOutput"] + _import_structure["pipeline_semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/semantic_stable_diffusion/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/semantic_stable_diffusion/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..8e5429ce2a8d1f5378caef2a726856f5b8ad943e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/semantic_stable_diffusion/pipeline_output.py @@ -0,0 +1,24 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class SemanticStableDiffusionPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + nsfw_content_detected (`list[bool]`) + list indicating whether the corresponding generated image contains “not-safe-for-work” (nsfw) content or + `None` if safety checking could not be performed. + """ + + images: list[PIL.Image.Image] | np.ndarray + nsfw_content_detected: list[bool] | None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..05d28896e1175800d5aca8798e59baaef5115f92 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py @@ -0,0 +1,732 @@ +import inspect +from itertools import repeat +from typing import Callable + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import deprecate, is_torch_xla_available, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import SemanticStableDiffusionPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class SemanticStableDiffusionPipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin): + _last_supported_version = "0.33.1" + r""" + Pipeline for text-to-image generation using Stable Diffusion with latent editing. + + This model inherits from [`DiffusionPipeline`] and builds on the [`StableDiffusionPipeline`]. Check the superclass + documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular + device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`Q16SafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str], + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + editing_prompt: str | list[str] | None = None, + editing_prompt_embeddings: torch.Tensor | None = None, + reverse_editing_direction: bool | list[bool] | None = False, + edit_guidance_scale: float | list[float] | None = 5, + edit_warmup_steps: int | list[int] | None = 10, + edit_cooldown_steps: int | list[int] | None = None, + edit_threshold: float | list[float] | None = 0.9, + edit_momentum_scale: float | None = 0.1, + edit_mom_beta: float | None = 0.4, + edit_weights: list[float] | None = None, + sem_guidance: list[torch.Tensor] | None = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide image generation. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + editing_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to use for semantic guidance. Semantic guidance is disabled by setting + `editing_prompt = None`. Guidance direction of prompt should be specified via + `reverse_editing_direction`. + editing_prompt_embeddings (`torch.Tensor`, *optional*): + Pre-computed embeddings to use for semantic guidance. Guidance direction of embedding should be + specified via `reverse_editing_direction`. + reverse_editing_direction (`bool` or `list[bool]`, *optional*, defaults to `False`): + Whether the corresponding prompt in `editing_prompt` should be increased or decreased. + edit_guidance_scale (`float` or `list[float]`, *optional*, defaults to 5): + Guidance scale for semantic guidance. If provided as a list, values should correspond to + `editing_prompt`. + edit_warmup_steps (`float` or `list[float]`, *optional*, defaults to 10): + Number of diffusion steps (for each prompt) for which semantic guidance is not applied. Momentum is + calculated for those steps and applied once all warmup periods are over. + edit_cooldown_steps (`float` or `list[float]`, *optional*, defaults to `None`): + Number of diffusion steps (for each prompt) after which semantic guidance is longer applied. + edit_threshold (`float` or `list[float]`, *optional*, defaults to 0.9): + Threshold of semantic guidance. + edit_momentum_scale (`float`, *optional*, defaults to 0.1): + Scale of the momentum to be added to the semantic guidance at each diffusion step. If set to 0.0, + momentum is disabled. Momentum is already built up during warmup (for diffusion steps smaller than + `sld_warmup_steps`). Momentum is only added to latent guidance once all warmup periods are finished. + edit_mom_beta (`float`, *optional*, defaults to 0.4): + Defines how semantic guidance momentum builds up. `edit_mom_beta` indicates how much of the previous + momentum is kept. Momentum is already built up during warmup (for diffusion steps smaller than + `edit_warmup_steps`). + edit_weights (`list[float]`, *optional*, defaults to `None`): + Indicates how much each individual concept should influence the overall guidance. If no weights are + provided all concepts are applied equally. + sem_guidance (`list[torch.Tensor]`, *optional*): + list of pre-generated guidance vectors to be applied at generation. Length of the list has to + correspond to `num_inference_steps`. + + Examples: + + ```py + >>> import torch + >>> from diffusers import SemanticStableDiffusionPipeline + + >>> pipe = SemanticStableDiffusionPipeline.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> out = pipe( + ... prompt="a photo of the face of a woman", + ... num_images_per_prompt=1, + ... guidance_scale=7, + ... editing_prompt=[ + ... "smiling, smile", # Concepts to apply + ... "glasses, wearing glasses", + ... "curls, wavy hair, curly hair", + ... "beard, full beard, mustache", + ... ], + ... reverse_editing_direction=[ + ... False, + ... False, + ... False, + ... False, + ... ], # Direction of guidance i.e. increase all concepts + ... edit_warmup_steps=[10, 10, 10, 10], # Warmup period for each concept + ... edit_guidance_scale=[4, 5, 5, 5.4], # Guidance scale for each concept + ... edit_threshold=[ + ... 0.99, + ... 0.975, + ... 0.925, + ... 0.96, + ... ], # Threshold for each concept. Threshold equals the percentile of the latent space that will be discarded. I.e. threshold=0.99 uses 1% of the latent dimensions + ... edit_momentum_scale=0.3, # Momentum scale that will be added to the latent guidance + ... edit_mom_beta=0.6, # Momentum beta + ... edit_weights=[1, 1, 1, 1, 1], # Weights of the individual concepts against each other + ... ) + >>> image = out.images[0] + ``` + + Returns: + [`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, + [`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] is returned, otherwise a + `tuple` is returned where the first element is a list with the generated images and the second element + is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" + (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + + if editing_prompt: + enable_edit_guidance = True + if isinstance(editing_prompt, str): + editing_prompt = [editing_prompt] + enabled_editing_prompts = len(editing_prompt) + elif editing_prompt_embeddings is not None: + enable_edit_guidance = True + enabled_editing_prompts = editing_prompt_embeddings.shape[0] + else: + enabled_editing_prompts = 0 + enable_edit_guidance = False + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + text_embeddings = self.text_encoder(text_input_ids.to(device))[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if enable_edit_guidance: + # get safety text embeddings + if editing_prompt_embeddings is None: + edit_concepts_input = self.tokenizer( + [x for item in editing_prompt for x in repeat(item, batch_size)], + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + + edit_concepts_input_ids = edit_concepts_input.input_ids + + if edit_concepts_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode( + edit_concepts_input_ids[:, self.tokenizer.model_max_length :] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + edit_concepts_input_ids = edit_concepts_input_ids[:, : self.tokenizer.model_max_length] + edit_concepts = self.text_encoder(edit_concepts_input_ids.to(device))[0] + else: + edit_concepts = editing_prompt_embeddings.to(device).repeat(batch_size, 1, 1) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed_edit, seq_len_edit, _ = edit_concepts.shape + edit_concepts = edit_concepts.repeat(1, num_images_per_prompt, 1) + edit_concepts = edit_concepts.view(bs_embed_edit * num_images_per_prompt, seq_len_edit, -1) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + + if do_classifier_free_guidance: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if enable_edit_guidance: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts]) + else: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + # get the initial random noise unless the user supplied it + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # Initialize edit_momentum to None + edit_momentum = None + + self.uncond_estimates = None + self.text_estimates = None + self.edit_estimates = None + self.sem_guidance = None + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * (2 + enabled_editing_prompts)) if do_classifier_free_guidance else latents + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_out = noise_pred.chunk(2 + enabled_editing_prompts) # [b,4, 64, 64] + noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] + noise_pred_edit_concepts = noise_pred_out[2:] + + # default text guidance + noise_guidance = guidance_scale * (noise_pred_text - noise_pred_uncond) + # noise_guidance = (noise_pred_text - noise_pred_edit_concepts[0]) + + if self.uncond_estimates is None: + self.uncond_estimates = torch.zeros((num_inference_steps + 1, *noise_pred_uncond.shape)) + self.uncond_estimates[i] = noise_pred_uncond.detach().cpu() + + if self.text_estimates is None: + self.text_estimates = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape)) + self.text_estimates[i] = noise_pred_text.detach().cpu() + + if self.edit_estimates is None and enable_edit_guidance: + self.edit_estimates = torch.zeros( + (num_inference_steps + 1, len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape) + ) + + if self.sem_guidance is None: + self.sem_guidance = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape)) + + if edit_momentum is None: + edit_momentum = torch.zeros_like(noise_guidance) + + if enable_edit_guidance: + concept_weights = torch.zeros( + (len(noise_pred_edit_concepts), noise_guidance.shape[0]), + device=device, + dtype=noise_guidance.dtype, + ) + noise_guidance_edit = torch.zeros( + (len(noise_pred_edit_concepts), *noise_guidance.shape), + device=device, + dtype=noise_guidance.dtype, + ) + # noise_guidance_edit = torch.zeros_like(noise_guidance) + warmup_inds = [] + for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts): + self.edit_estimates[i, c] = noise_pred_edit_concept + if isinstance(edit_guidance_scale, list): + edit_guidance_scale_c = edit_guidance_scale[c] + else: + edit_guidance_scale_c = edit_guidance_scale + + if isinstance(edit_threshold, list): + edit_threshold_c = edit_threshold[c] + else: + edit_threshold_c = edit_threshold + if isinstance(reverse_editing_direction, list): + reverse_editing_direction_c = reverse_editing_direction[c] + else: + reverse_editing_direction_c = reverse_editing_direction + if edit_weights: + edit_weight_c = edit_weights[c] + else: + edit_weight_c = 1.0 + if isinstance(edit_warmup_steps, list): + edit_warmup_steps_c = edit_warmup_steps[c] + else: + edit_warmup_steps_c = edit_warmup_steps + + if isinstance(edit_cooldown_steps, list): + edit_cooldown_steps_c = edit_cooldown_steps[c] + elif edit_cooldown_steps is None: + edit_cooldown_steps_c = i + 1 + else: + edit_cooldown_steps_c = edit_cooldown_steps + if i >= edit_warmup_steps_c: + warmup_inds.append(c) + if i >= edit_cooldown_steps_c: + noise_guidance_edit[c, :, :, :, :] = torch.zeros_like(noise_pred_edit_concept) + continue + + noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond + # tmp_weights = (noise_pred_text - noise_pred_edit_concept).sum(dim=(1, 2, 3)) + tmp_weights = (noise_guidance - noise_pred_edit_concept).sum(dim=(1, 2, 3)) + + tmp_weights = torch.full_like(tmp_weights, edit_weight_c) # * (1 / enabled_editing_prompts) + if reverse_editing_direction_c: + noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1 + concept_weights[c, :] = tmp_weights + + noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c + + # torch.quantile function expects float32 + if noise_guidance_edit_tmp.dtype == torch.float32: + tmp = torch.quantile( + torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2), + edit_threshold_c, + dim=2, + keepdim=False, + ) + else: + tmp = torch.quantile( + torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2).to(torch.float32), + edit_threshold_c, + dim=2, + keepdim=False, + ).to(noise_guidance_edit_tmp.dtype) + + noise_guidance_edit_tmp = torch.where( + torch.abs(noise_guidance_edit_tmp) >= tmp[:, :, None, None], + noise_guidance_edit_tmp, + torch.zeros_like(noise_guidance_edit_tmp), + ) + noise_guidance_edit[c, :, :, :, :] = noise_guidance_edit_tmp + + # noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp + + warmup_inds = torch.tensor(warmup_inds).to(device) + if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0: + concept_weights = concept_weights.to("cpu") # Offload to cpu + noise_guidance_edit = noise_guidance_edit.to("cpu") + + concept_weights_tmp = torch.index_select(concept_weights.to(device), 0, warmup_inds) + concept_weights_tmp = torch.where( + concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp + ) + concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0) + # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp) + + noise_guidance_edit_tmp = torch.index_select(noise_guidance_edit.to(device), 0, warmup_inds) + noise_guidance_edit_tmp = torch.einsum( + "cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp + ) + noise_guidance = noise_guidance + noise_guidance_edit_tmp + + self.sem_guidance[i] = noise_guidance_edit_tmp.detach().cpu() + + del noise_guidance_edit_tmp + del concept_weights_tmp + concept_weights = concept_weights.to(device) + noise_guidance_edit = noise_guidance_edit.to(device) + + concept_weights = torch.where( + concept_weights < 0, torch.zeros_like(concept_weights), concept_weights + ) + + concept_weights = torch.nan_to_num(concept_weights) + + noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit) + noise_guidance_edit = noise_guidance_edit.to(edit_momentum.device) + + noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum + + edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * noise_guidance_edit + + if warmup_inds.shape[0] == len(noise_pred_edit_concepts): + noise_guidance = noise_guidance + noise_guidance_edit + self.sem_guidance[i] = noise_guidance_edit.detach().cpu() + + if sem_guidance is not None: + edit_guidance = sem_guidance[i].to(device) + noise_guidance = noise_guidance + edit_guidance + + noise_pred = noise_pred_uncond + noise_guidance + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # 8. Post-processing + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if not return_dict: + return (image, has_nsfw_concept) + + return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/shap_e/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/shap_e/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4ed563c4a51f6e627c06711b60fe3a0709ff22f7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/shap_e/__init__.py @@ -0,0 +1,71 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["camera"] = ["create_pan_cameras"] + _import_structure["pipeline_shap_e"] = ["ShapEPipeline"] + _import_structure["pipeline_shap_e_img2img"] = ["ShapEImg2ImgPipeline"] + _import_structure["renderer"] = [ + "BoundingBoxVolume", + "ImportanceRaySampler", + "MLPNeRFModelOutput", + "MLPNeRSTFModel", + "ShapEParamsProjModel", + "ShapERenderer", + "StratifiedRaySampler", + "VoidNeRFModel", + ] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .camera import create_pan_cameras + from .pipeline_shap_e import ShapEPipeline + from .pipeline_shap_e_img2img import ShapEImg2ImgPipeline + from .renderer import ( + BoundingBoxVolume, + ImportanceRaySampler, + MLPNeRFModelOutput, + MLPNeRSTFModel, + ShapEParamsProjModel, + ShapERenderer, + StratifiedRaySampler, + VoidNeRFModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/shap_e/camera.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/shap_e/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..81807b6ff35a9e3abf400fd7cddfc710244f367b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/shap_e/camera.py @@ -0,0 +1,146 @@ +# Copyright 2025 Open AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import numpy as np +import torch + + +@dataclass +class DifferentiableProjectiveCamera: + """ + Implements a batch, differentiable, standard pinhole camera + """ + + origin: torch.Tensor # [batch_size x 3] + x: torch.Tensor # [batch_size x 3] + y: torch.Tensor # [batch_size x 3] + z: torch.Tensor # [batch_size x 3] + width: int + height: int + x_fov: float + y_fov: float + shape: tuple[int] + + def __post_init__(self): + assert self.x.shape[0] == self.y.shape[0] == self.z.shape[0] == self.origin.shape[0] + assert self.x.shape[1] == self.y.shape[1] == self.z.shape[1] == self.origin.shape[1] == 3 + assert len(self.x.shape) == len(self.y.shape) == len(self.z.shape) == len(self.origin.shape) == 2 + + def resolution(self): + return torch.from_numpy(np.array([self.width, self.height], dtype=np.float32)) + + def fov(self): + return torch.from_numpy(np.array([self.x_fov, self.y_fov], dtype=np.float32)) + + def get_image_coords(self) -> torch.Tensor: + """ + :return: coords of shape (width * height, 2) + """ + pixel_indices = torch.arange(self.height * self.width) + coords = torch.stack( + [ + pixel_indices % self.width, + torch.div(pixel_indices, self.width, rounding_mode="trunc"), + ], + axis=1, + ) + return coords + + @property + def camera_rays(self): + batch_size, *inner_shape = self.shape + inner_batch_size = int(np.prod(inner_shape)) + + coords = self.get_image_coords() + coords = torch.broadcast_to(coords.unsqueeze(0), [batch_size * inner_batch_size, *coords.shape]) + rays = self.get_camera_rays(coords) + + rays = rays.view(batch_size, inner_batch_size * self.height * self.width, 2, 3) + + return rays + + def get_camera_rays(self, coords: torch.Tensor) -> torch.Tensor: + batch_size, *shape, n_coords = coords.shape + assert n_coords == 2 + assert batch_size == self.origin.shape[0] + + flat = coords.view(batch_size, -1, 2) + + res = self.resolution() + fov = self.fov() + + fracs = (flat.float() / (res - 1)) * 2 - 1 + fracs = fracs * torch.tan(fov / 2) + + fracs = fracs.view(batch_size, -1, 2) + directions = ( + self.z.view(batch_size, 1, 3) + + self.x.view(batch_size, 1, 3) * fracs[:, :, :1] + + self.y.view(batch_size, 1, 3) * fracs[:, :, 1:] + ) + directions = directions / directions.norm(dim=-1, keepdim=True) + rays = torch.stack( + [ + torch.broadcast_to(self.origin.view(batch_size, 1, 3), [batch_size, directions.shape[1], 3]), + directions, + ], + dim=2, + ) + return rays.view(batch_size, *shape, 2, 3) + + def resize_image(self, width: int, height: int) -> "DifferentiableProjectiveCamera": + """ + Creates a new camera for the resized view assuming the aspect ratio does not change. + """ + assert width * self.height == height * self.width, "The aspect ratio should not change." + return DifferentiableProjectiveCamera( + origin=self.origin, + x=self.x, + y=self.y, + z=self.z, + width=width, + height=height, + x_fov=self.x_fov, + y_fov=self.y_fov, + ) + + +def create_pan_cameras(size: int) -> DifferentiableProjectiveCamera: + origins = [] + xs = [] + ys = [] + zs = [] + for theta in np.linspace(0, 2 * np.pi, num=20): + z = np.array([np.sin(theta), np.cos(theta), -0.5]) + z /= np.sqrt(np.sum(z**2)) + origin = -z * 4 + x = np.array([np.cos(theta), -np.sin(theta), 0.0]) + y = np.cross(z, x) + origins.append(origin) + xs.append(x) + ys.append(y) + zs.append(z) + return DifferentiableProjectiveCamera( + origin=torch.from_numpy(np.stack(origins, axis=0)).float(), + x=torch.from_numpy(np.stack(xs, axis=0)).float(), + y=torch.from_numpy(np.stack(ys, axis=0)).float(), + z=torch.from_numpy(np.stack(zs, axis=0)).float(), + width=size, + height=size, + x_fov=0.7, + y_fov=0.7, + shape=(1, len(xs)), + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/shap_e/pipeline_shap_e.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/shap_e/pipeline_shap_e.py new file mode 100644 index 0000000000000000000000000000000000000000..44967dfb334922a67f50854434be9602a8f7023b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/shap_e/pipeline_shap_e.py @@ -0,0 +1,345 @@ +# Copyright 2025 Open AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from ...models import PriorTransformer +from ...schedulers import HeunDiscreteScheduler +from ...utils import ( + BaseOutput, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .renderer import ShapERenderer + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import DiffusionPipeline + >>> from diffusers.utils import export_to_gif + + >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + >>> repo = "openai/shap-e" + >>> pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16) + >>> pipe = pipe.to(device) + + >>> guidance_scale = 15.0 + >>> prompt = "a shark" + + >>> images = pipe( + ... prompt, + ... guidance_scale=guidance_scale, + ... num_inference_steps=64, + ... frame_size=256, + ... ).images + + >>> gif_path = export_to_gif(images[0], "shark_3d.gif") + ``` +""" + + +@dataclass +class ShapEPipelineOutput(BaseOutput): + """ + Output class for [`ShapEPipeline`] and [`ShapEImg2ImgPipeline`]. + + Args: + images (`torch.Tensor`) + A list of images for 3D rendering. + """ + + images: list[list[PIL.Image.Image]] | list[list[np.ndarray]] + + +class ShapEPipeline(DiffusionPipeline): + """ + Pipeline for generating latent representation of a 3D asset and rendering with the NeRF method. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + text_encoder ([`~transformers.CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + scheduler ([`HeunDiscreteScheduler`]): + A scheduler to be used in combination with the `prior` model to generate image embedding. + shap_e_renderer ([`ShapERenderer`]): + Shap-E renderer projects the generated latents into parameters of a MLP to create 3D objects with the NeRF + rendering method. + """ + + model_cpu_offload_seq = "text_encoder->prior" + _exclude_from_cpu_offload = ["shap_e_renderer"] + + def __init__( + self, + prior: PriorTransformer, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + scheduler: HeunDiscreteScheduler, + shap_e_renderer: ShapERenderer, + ): + super().__init__() + + self.register_modules( + prior=prior, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + shap_e_renderer=shap_e_renderer, + ) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + ): + len(prompt) if isinstance(prompt, list) else 1 + + # YiYi Notes: set pad_token_id to be 0, not sure why I can't set in the config file + self.tokenizer.pad_token_id = 0 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + prompt_embeds = text_encoder_output.text_embeds + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + # in Shap-E it normalize the prompt_embeds and then later rescale it + prompt_embeds = prompt_embeds / torch.linalg.norm(prompt_embeds, dim=-1, keepdim=True) + + if do_classifier_free_guidance: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # Rescale the features to have unit variance + prompt_embeds = math.sqrt(prompt_embeds.shape[1]) * prompt_embeds + + return prompt_embeds + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str, + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + guidance_scale: float = 4.0, + frame_size: int = 64, + output_type: str | None = "pil", # pil, np, latent, mesh + return_dict: bool = True, + ): + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + guidance_scale (`float`, *optional*, defaults to 4.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + frame_size (`int`, *optional*, default to 64): + The width and height of each image frame of the generated 3D output. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`), `"latent"` (`torch.Tensor`), or mesh ([`MeshDecoderOutput`]). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.shap_e.pipeline_shap_e.ShapEPipelineOutput`] instead of a plain + tuple. + + Examples: + + Returns: + [`~pipelines.shap_e.pipeline_shap_e.ShapEPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.shap_e.pipeline_shap_e.ShapEPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images. + """ + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance) + + # prior + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + num_embeddings = self.prior.config.num_embeddings + embedding_dim = self.prior.config.embedding_dim + + latents = self.prepare_latents( + (batch_size, num_embeddings * embedding_dim), + prompt_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + + # YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim + latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim) + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + noise_pred = self.prior( + scaled_model_input, + timestep=t, + proj_embedding=prompt_embeds, + ).predicted_image_embedding + + # remove the variance + noise_pred, _ = noise_pred.split( + scaled_model_input.shape[2], dim=2 + ) # batch_size, num_embeddings, embedding_dim + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond) + + latents = self.scheduler.step( + noise_pred, + timestep=t, + sample=latents, + ).prev_sample + + if XLA_AVAILABLE: + xm.mark_step() + + # Offload all models + self.maybe_free_model_hooks() + + if output_type not in ["np", "pil", "latent", "mesh"]: + raise ValueError( + f"Only the output types `pil`, `np`, `latent` and `mesh` are supported not output_type={output_type}" + ) + + if output_type == "latent": + return ShapEPipelineOutput(images=latents) + + images = [] + if output_type == "mesh": + for i, latent in enumerate(latents): + mesh = self.shap_e_renderer.decode_to_mesh( + latent[None, :], + device, + ) + images.append(mesh) + + else: + # np, pil + for i, latent in enumerate(latents): + image = self.shap_e_renderer.decode_to_image( + latent[None, :], + device, + size=frame_size, + ) + images.append(image) + + images = torch.stack(images) + + images = images.cpu().numpy() + + if output_type == "pil": + images = [self.numpy_to_pil(image) for image in images] + + if not return_dict: + return (images,) + + return ShapEPipelineOutput(images=images) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..964db30e7de2588f7011d96ae184ba4c25103ff5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py @@ -0,0 +1,332 @@ +# Copyright 2025 Open AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPVisionModel + +from ...models import PriorTransformer +from ...schedulers import HeunDiscreteScheduler +from ...utils import ( + BaseOutput, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .renderer import ShapERenderer + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from PIL import Image + >>> import torch + >>> from diffusers import DiffusionPipeline + >>> from diffusers.utils import export_to_gif, load_image + + >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + >>> repo = "openai/shap-e-img2img" + >>> pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16) + >>> pipe = pipe.to(device) + + >>> guidance_scale = 3.0 + >>> image_url = "https://hf.co/datasets/diffusers/docs-images/resolve/main/shap-e/corgi.png" + >>> image = load_image(image_url).convert("RGB") + + >>> images = pipe( + ... image, + ... guidance_scale=guidance_scale, + ... num_inference_steps=64, + ... frame_size=256, + ... ).images + + >>> gif_path = export_to_gif(images[0], "corgi_3d.gif") + ``` +""" + + +@dataclass +class ShapEPipelineOutput(BaseOutput): + """ + Output class for [`ShapEPipeline`] and [`ShapEImg2ImgPipeline`]. + + Args: + images (`torch.Tensor`) + A list of images for 3D rendering. + """ + + images: PIL.Image.Image | np.ndarray + + +class ShapEImg2ImgPipeline(DiffusionPipeline): + """ + Pipeline for generating latent representation of a 3D asset and rendering with the NeRF method from an image. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + image_encoder ([`~transformers.CLIPVisionModel`]): + Frozen image-encoder. + image_processor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to process images. + scheduler ([`HeunDiscreteScheduler`]): + A scheduler to be used in combination with the `prior` model to generate image embedding. + shap_e_renderer ([`ShapERenderer`]): + Shap-E renderer projects the generated latents into parameters of a MLP to create 3D objects with the NeRF + rendering method. + """ + + model_cpu_offload_seq = "image_encoder->prior" + _exclude_from_cpu_offload = ["shap_e_renderer"] + + def __init__( + self, + prior: PriorTransformer, + image_encoder: CLIPVisionModel, + image_processor: CLIPImageProcessor, + scheduler: HeunDiscreteScheduler, + shap_e_renderer: ShapERenderer, + ): + super().__init__() + + self.register_modules( + prior=prior, + image_encoder=image_encoder, + image_processor=image_processor, + scheduler=scheduler, + shap_e_renderer=shap_e_renderer, + ) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_image( + self, + image, + device, + num_images_per_prompt, + do_classifier_free_guidance, + ): + if isinstance(image, list) and isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + if not isinstance(image, torch.Tensor): + image = self.image_processor(image, return_tensors="pt").pixel_values[0].unsqueeze(0) + + image = image.to(dtype=self.image_encoder.dtype, device=device) + + image_embeds = self.image_encoder(image)["last_hidden_state"] + image_embeds = image_embeds[:, 1:, :].contiguous() # batch_size, dim, 256 + + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + negative_image_embeds = torch.zeros_like(image_embeds) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + + return image_embeds + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PIL.Image.Image | list[PIL.Image.Image], + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + guidance_scale: float = 4.0, + frame_size: int = 64, + output_type: str | None = "pil", # pil, np, latent, mesh + return_dict: bool = True, + ): + """ + The call function to the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image` or tensor representing an image batch to be used as the starting point. Can also accept image + latents as image, but if passing latents directly it is not encoded again. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + guidance_scale (`float`, *optional*, defaults to 4.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + frame_size (`int`, *optional*, default to 64): + The width and height of each image frame of the generated 3D output. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`), `"latent"` (`torch.Tensor`), or mesh ([`MeshDecoderOutput`]). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.shap_e.pipeline_shap_e.ShapEPipelineOutput`] instead of a plain + tuple. + + Examples: + + Returns: + [`~pipelines.shap_e.pipeline_shap_e.ShapEPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.shap_e.pipeline_shap_e.ShapEPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images. + """ + + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, torch.Tensor): + batch_size = image.shape[0] + elif isinstance(image, list) and isinstance(image[0], (torch.Tensor, PIL.Image.Image)): + batch_size = len(image) + else: + raise ValueError( + f"`image` has to be of type `PIL.Image.Image`, `torch.Tensor`, `list[PIL.Image.Image]` or `list[torch.Tensor]` but is {type(image)}" + ) + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + image_embeds = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance) + + # prior + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + num_embeddings = self.prior.config.num_embeddings + embedding_dim = self.prior.config.embedding_dim + if latents is None: + latents = self.prepare_latents( + (batch_size, num_embeddings * embedding_dim), + image_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + + # YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim + latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim) + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + noise_pred = self.prior( + scaled_model_input, + timestep=t, + proj_embedding=image_embeds, + ).predicted_image_embedding + + # remove the variance + noise_pred, _ = noise_pred.split( + scaled_model_input.shape[2], dim=2 + ) # batch_size, num_embeddings, embedding_dim + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond) + + latents = self.scheduler.step( + noise_pred, + timestep=t, + sample=latents, + ).prev_sample + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type not in ["np", "pil", "latent", "mesh"]: + raise ValueError( + f"Only the output types `pil`, `np`, `latent` and `mesh` are supported not output_type={output_type}" + ) + + # Offload all models + self.maybe_free_model_hooks() + + if output_type == "latent": + return ShapEPipelineOutput(images=latents) + + images = [] + if output_type == "mesh": + for i, latent in enumerate(latents): + mesh = self.shap_e_renderer.decode_to_mesh( + latent[None, :], + device, + ) + images.append(mesh) + + else: + # np, pil + for i, latent in enumerate(latents): + image = self.shap_e_renderer.decode_to_image( + latent[None, :], + device, + size=frame_size, + ) + images.append(image) + + images = torch.stack(images) + + images = images.cpu().numpy() + + if output_type == "pil": + images = [self.numpy_to_pil(image) for image in images] + + if not return_dict: + return (images,) + + return ShapEPipelineOutput(images=images) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/shap_e/renderer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/shap_e/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..0c2058c887fc187060e94a45534129c8efcfda04 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/shap_e/renderer.py @@ -0,0 +1,1049 @@ +# Copyright 2025 Open AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin +from ...utils import BaseOutput +from .camera import create_pan_cameras + + +def sample_pmf(pmf: torch.Tensor, n_samples: int) -> torch.Tensor: + r""" + Sample from the given discrete probability distribution with replacement. + + The i-th bin is assumed to have mass pmf[i]. + + Args: + pmf: [batch_size, *shape, n_samples, 1] where (pmf.sum(dim=-2) == 1).all() + n_samples: number of samples + + Return: + indices sampled with replacement + """ + + *shape, support_size, last_dim = pmf.shape + assert last_dim == 1 + + cdf = torch.cumsum(pmf.view(-1, support_size), dim=1) + inds = torch.searchsorted(cdf, torch.rand(cdf.shape[0], n_samples, device=cdf.device)) + + return inds.view(*shape, n_samples, 1).clamp(0, support_size - 1) + + +def posenc_nerf(x: torch.Tensor, min_deg: int = 0, max_deg: int = 15) -> torch.Tensor: + """ + Concatenate x and its positional encodings, following NeRF. + + Reference: https://huggingface.co/papers/2210.04628 + """ + if min_deg == max_deg: + return x + + scales = 2.0 ** torch.arange(min_deg, max_deg, dtype=x.dtype, device=x.device) + *shape, dim = x.shape + xb = (x.reshape(-1, 1, dim) * scales.view(1, -1, 1)).reshape(*shape, -1) + assert xb.shape[-1] == dim * (max_deg - min_deg) + emb = torch.cat([xb, xb + math.pi / 2.0], axis=-1).sin() + return torch.cat([x, emb], dim=-1) + + +def encode_position(position): + return posenc_nerf(position, min_deg=0, max_deg=15) + + +def encode_direction(position, direction=None): + if direction is None: + return torch.zeros_like(posenc_nerf(position, min_deg=0, max_deg=8)) + else: + return posenc_nerf(direction, min_deg=0, max_deg=8) + + +def _sanitize_name(x: str) -> str: + return x.replace(".", "__") + + +def integrate_samples(volume_range, ts, density, channels): + r""" + Function integrating the model output. + + Args: + volume_range: Specifies the integral range [t0, t1] + ts: timesteps + density: torch.Tensor [batch_size, *shape, n_samples, 1] + channels: torch.Tensor [batch_size, *shape, n_samples, n_channels] + returns: + channels: integrated rgb output weights: torch.Tensor [batch_size, *shape, n_samples, 1] (density + *transmittance)[i] weight for each rgb output at [..., i, :]. transmittance: transmittance of this volume + ) + """ + + # 1. Calculate the weights + _, _, dt = volume_range.partition(ts) + ddensity = density * dt + + mass = torch.cumsum(ddensity, dim=-2) + transmittance = torch.exp(-mass[..., -1, :]) + + alphas = 1.0 - torch.exp(-ddensity) + Ts = torch.exp(torch.cat([torch.zeros_like(mass[..., :1, :]), -mass[..., :-1, :]], dim=-2)) + # This is the probability of light hitting and reflecting off of + # something at depth [..., i, :]. + weights = alphas * Ts + + # 2. Integrate channels + channels = torch.sum(channels * weights, dim=-2) + + return channels, weights, transmittance + + +def volume_query_points(volume, grid_size): + indices = torch.arange(grid_size**3, device=volume.bbox_min.device) + zs = indices % grid_size + ys = torch.div(indices, grid_size, rounding_mode="trunc") % grid_size + xs = torch.div(indices, grid_size**2, rounding_mode="trunc") % grid_size + combined = torch.stack([xs, ys, zs], dim=1) + return (combined.float() / (grid_size - 1)) * (volume.bbox_max - volume.bbox_min) + volume.bbox_min + + +def _convert_srgb_to_linear(u: torch.Tensor): + return torch.where(u <= 0.04045, u / 12.92, ((u + 0.055) / 1.055) ** 2.4) + + +def _create_flat_edge_indices( + flat_cube_indices: torch.Tensor, + grid_size: tuple[int, int, int], +): + num_xs = (grid_size[0] - 1) * grid_size[1] * grid_size[2] + y_offset = num_xs + num_ys = grid_size[0] * (grid_size[1] - 1) * grid_size[2] + z_offset = num_xs + num_ys + return torch.stack( + [ + # Edges spanning x-axis. + flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2], + flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] + + (flat_cube_indices[:, 1] + 1) * grid_size[2] + + flat_cube_indices[:, 2], + flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + + 1, + flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] + + (flat_cube_indices[:, 1] + 1) * grid_size[2] + + flat_cube_indices[:, 2] + + 1, + # Edges spanning y-axis. + ( + y_offset + + flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + ), + ( + y_offset + + (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + ), + ( + y_offset + + flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + + 1 + ), + ( + y_offset + + (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + + 1 + ), + # Edges spanning z-axis. + ( + z_offset + + flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1) + + flat_cube_indices[:, 1] * (grid_size[2] - 1) + + flat_cube_indices[:, 2] + ), + ( + z_offset + + (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1) + + flat_cube_indices[:, 1] * (grid_size[2] - 1) + + flat_cube_indices[:, 2] + ), + ( + z_offset + + flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1) + + (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1) + + flat_cube_indices[:, 2] + ), + ( + z_offset + + (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1) + + (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1) + + flat_cube_indices[:, 2] + ), + ], + dim=-1, + ) + + +class VoidNeRFModel(nn.Module): + """ + Implements the default empty space model where all queries are rendered as background. + """ + + def __init__(self, background, channel_scale=255.0): + super().__init__() + background = nn.Parameter(torch.from_numpy(np.array(background)).to(dtype=torch.float32) / channel_scale) + + self.register_buffer("background", background) + + def forward(self, position): + background = self.background[None].to(position.device) + + shape = position.shape[:-1] + ones = [1] * (len(shape) - 1) + n_channels = background.shape[-1] + background = torch.broadcast_to(background.view(background.shape[0], *ones, n_channels), [*shape, n_channels]) + + return background + + +@dataclass +class VolumeRange: + t0: torch.Tensor + t1: torch.Tensor + intersected: torch.Tensor + + def __post_init__(self): + assert self.t0.shape == self.t1.shape == self.intersected.shape + + def partition(self, ts): + """ + Partitions t0 and t1 into n_samples intervals. + + Args: + ts: [batch_size, *shape, n_samples, 1] + + Return: + + lower: [batch_size, *shape, n_samples, 1] upper: [batch_size, *shape, n_samples, 1] delta: [batch_size, + *shape, n_samples, 1] + + where + ts \\in [lower, upper] deltas = upper - lower + """ + + mids = (ts[..., 1:, :] + ts[..., :-1, :]) * 0.5 + lower = torch.cat([self.t0[..., None, :], mids], dim=-2) + upper = torch.cat([mids, self.t1[..., None, :]], dim=-2) + delta = upper - lower + assert lower.shape == upper.shape == delta.shape == ts.shape + return lower, upper, delta + + +class BoundingBoxVolume(nn.Module): + """ + Axis-aligned bounding box defined by the two opposite corners. + """ + + def __init__( + self, + *, + bbox_min, + bbox_max, + min_dist: float = 0.0, + min_t_range: float = 1e-3, + ): + """ + Args: + bbox_min: the left/bottommost corner of the bounding box + bbox_max: the other corner of the bounding box + min_dist: all rays should start at least this distance away from the origin. + """ + super().__init__() + + self.min_dist = min_dist + self.min_t_range = min_t_range + + self.bbox_min = torch.tensor(bbox_min) + self.bbox_max = torch.tensor(bbox_max) + self.bbox = torch.stack([self.bbox_min, self.bbox_max]) + assert self.bbox.shape == (2, 3) + assert min_dist >= 0.0 + assert min_t_range > 0.0 + + def intersect( + self, + origin: torch.Tensor, + direction: torch.Tensor, + t0_lower: torch.Tensor | None = None, + epsilon=1e-6, + ): + """ + Args: + origin: [batch_size, *shape, 3] + direction: [batch_size, *shape, 3] + t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume. + params: Optional meta parameters in case Volume is parametric + epsilon: to stabilize calculations + + Return: + A tuple of (t0, t1, intersected) where each has a shape [batch_size, *shape, 1]. If a ray intersects with + the volume, `o + td` is in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed to + be on the boundary of the volume. + """ + + batch_size, *shape, _ = origin.shape + ones = [1] * len(shape) + bbox = self.bbox.view(1, *ones, 2, 3).to(origin.device) + + def _safe_divide(a, b, epsilon=1e-6): + return a / torch.where(b < 0, b - epsilon, b + epsilon) + + ts = _safe_divide(bbox - origin[..., None, :], direction[..., None, :], epsilon=epsilon) + + # Cases to think about: + # + # 1. t1 <= t0: the ray does not pass through the AABB. + # 2. t0 < t1 <= 0: the ray intersects but the BB is behind the origin. + # 3. t0 <= 0 <= t1: the ray starts from inside the BB + # 4. 0 <= t0 < t1: the ray is not inside and intersects with the BB twice. + # + # 1 and 4 are clearly handled from t0 < t1 below. + # Making t0 at least min_dist (>= 0) takes care of 2 and 3. + t0 = ts.min(dim=-2).values.max(dim=-1, keepdim=True).values.clamp(self.min_dist) + t1 = ts.max(dim=-2).values.min(dim=-1, keepdim=True).values + assert t0.shape == t1.shape == (batch_size, *shape, 1) + if t0_lower is not None: + assert t0.shape == t0_lower.shape + t0 = torch.maximum(t0, t0_lower) + + intersected = t0 + self.min_t_range < t1 + t0 = torch.where(intersected, t0, torch.zeros_like(t0)) + t1 = torch.where(intersected, t1, torch.ones_like(t1)) + + return VolumeRange(t0=t0, t1=t1, intersected=intersected) + + +class StratifiedRaySampler(nn.Module): + """ + Instead of fixed intervals, a sample is drawn uniformly at random from each interval. + """ + + def __init__(self, depth_mode: str = "linear"): + """ + :param depth_mode: linear samples ts linearly in depth. harmonic ensures + closer points are sampled more densely. + """ + self.depth_mode = depth_mode + assert self.depth_mode in ("linear", "geometric", "harmonic") + + def sample( + self, + t0: torch.Tensor, + t1: torch.Tensor, + n_samples: int, + epsilon: float = 1e-3, + ) -> torch.Tensor: + """ + Args: + t0: start time has shape [batch_size, *shape, 1] + t1: finish time has shape [batch_size, *shape, 1] + n_samples: number of ts to sample + Return: + sampled ts of shape [batch_size, *shape, n_samples, 1] + """ + ones = [1] * (len(t0.shape) - 1) + ts = torch.linspace(0, 1, n_samples).view(*ones, n_samples).to(t0.dtype).to(t0.device) + + if self.depth_mode == "linear": + ts = t0 * (1.0 - ts) + t1 * ts + elif self.depth_mode == "geometric": + ts = (t0.clamp(epsilon).log() * (1.0 - ts) + t1.clamp(epsilon).log() * ts).exp() + elif self.depth_mode == "harmonic": + # The original NeRF recommends this interpolation scheme for + # spherical scenes, but there could be some weird edge cases when + # the observer crosses from the inner to outer volume. + ts = 1.0 / (1.0 / t0.clamp(epsilon) * (1.0 - ts) + 1.0 / t1.clamp(epsilon) * ts) + + mids = 0.5 * (ts[..., 1:] + ts[..., :-1]) + upper = torch.cat([mids, t1], dim=-1) + lower = torch.cat([t0, mids], dim=-1) + # yiyi notes: add a random seed here for testing, don't forget to remove + torch.manual_seed(0) + t_rand = torch.rand_like(ts) + + ts = lower + (upper - lower) * t_rand + return ts.unsqueeze(-1) + + +class ImportanceRaySampler(nn.Module): + """ + Given the initial estimate of densities, this samples more from regions/bins expected to have objects. + """ + + def __init__( + self, + volume_range: VolumeRange, + ts: torch.Tensor, + weights: torch.Tensor, + blur_pool: bool = False, + alpha: float = 1e-5, + ): + """ + Args: + volume_range: the range in which a ray intersects the given volume. + ts: earlier samples from the coarse rendering step + weights: discretized version of density * transmittance + blur_pool: if true, use 2-tap max + 2-tap blur filter from mip-NeRF. + alpha: small value to add to weights. + """ + self.volume_range = volume_range + self.ts = ts.clone().detach() + self.weights = weights.clone().detach() + self.blur_pool = blur_pool + self.alpha = alpha + + @torch.no_grad() + def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor: + """ + Args: + t0: start time has shape [batch_size, *shape, 1] + t1: finish time has shape [batch_size, *shape, 1] + n_samples: number of ts to sample + Return: + sampled ts of shape [batch_size, *shape, n_samples, 1] + """ + lower, upper, _ = self.volume_range.partition(self.ts) + + batch_size, *shape, n_coarse_samples, _ = self.ts.shape + + weights = self.weights + if self.blur_pool: + padded = torch.cat([weights[..., :1, :], weights, weights[..., -1:, :]], dim=-2) + maxes = torch.maximum(padded[..., :-1, :], padded[..., 1:, :]) + weights = 0.5 * (maxes[..., :-1, :] + maxes[..., 1:, :]) + weights = weights + self.alpha + pmf = weights / weights.sum(dim=-2, keepdim=True) + inds = sample_pmf(pmf, n_samples) + assert inds.shape == (batch_size, *shape, n_samples, 1) + assert (inds >= 0).all() and (inds < n_coarse_samples).all() + + t_rand = torch.rand(inds.shape, device=inds.device) + lower_ = torch.gather(lower, -2, inds) + upper_ = torch.gather(upper, -2, inds) + + ts = lower_ + (upper_ - lower_) * t_rand + ts = torch.sort(ts, dim=-2).values + return ts + + +@dataclass +class MeshDecoderOutput(BaseOutput): + """ + A 3D triangle mesh with optional data at the vertices and faces. + + Args: + verts (`torch.Tensor` of shape `(N, 3)`): + array of vertext coordinates + faces (`torch.Tensor` of shape `(N, 3)`): + array of triangles, pointing to indices in verts. + vertext_channels (Dict): + vertext coordinates for each color channel + """ + + verts: torch.Tensor + faces: torch.Tensor + vertex_channels: dict[str, torch.Tensor] + + +class MeshDecoder(nn.Module): + """ + Construct meshes from Signed distance functions (SDFs) using marching cubes method + """ + + def __init__(self): + super().__init__() + cases = torch.zeros(256, 5, 3, dtype=torch.long) + masks = torch.zeros(256, 5, dtype=torch.bool) + + self.register_buffer("cases", cases) + self.register_buffer("masks", masks) + + def forward(self, field: torch.Tensor, min_point: torch.Tensor, size: torch.Tensor): + """ + For a signed distance field, produce a mesh using marching cubes. + + :param field: a 3D tensor of field values, where negative values correspond + to the outside of the shape. The dimensions correspond to the x, y, and z directions, respectively. + :param min_point: a tensor of shape [3] containing the point corresponding + to (0, 0, 0) in the field. + :param size: a tensor of shape [3] containing the per-axis distance from the + (0, 0, 0) field corner and the (-1, -1, -1) field corner. + """ + assert len(field.shape) == 3, "input must be a 3D scalar field" + dev = field.device + + cases = self.cases.to(dev) + masks = self.masks.to(dev) + + min_point = min_point.to(dev) + size = size.to(dev) + + grid_size = field.shape + grid_size_tensor = torch.tensor(grid_size).to(size) + + # Create bitmasks between 0 and 255 (inclusive) indicating the state + # of the eight corners of each cube. + bitmasks = (field > 0).to(torch.uint8) + bitmasks = bitmasks[:-1, :, :] | (bitmasks[1:, :, :] << 1) + bitmasks = bitmasks[:, :-1, :] | (bitmasks[:, 1:, :] << 2) + bitmasks = bitmasks[:, :, :-1] | (bitmasks[:, :, 1:] << 4) + + # Compute corner coordinates across the entire grid. + corner_coords = torch.empty(*grid_size, 3, device=dev, dtype=field.dtype) + corner_coords[range(grid_size[0]), :, :, 0] = torch.arange(grid_size[0], device=dev, dtype=field.dtype)[ + :, None, None + ] + corner_coords[:, range(grid_size[1]), :, 1] = torch.arange(grid_size[1], device=dev, dtype=field.dtype)[ + :, None + ] + corner_coords[:, :, range(grid_size[2]), 2] = torch.arange(grid_size[2], device=dev, dtype=field.dtype) + + # Compute all vertices across all edges in the grid, even though we will + # throw some out later. We have (X-1)*Y*Z + X*(Y-1)*Z + X*Y*(Z-1) vertices. + # These are all midpoints, and don't account for interpolation (which is + # done later based on the used edge midpoints). + edge_midpoints = torch.cat( + [ + ((corner_coords[:-1] + corner_coords[1:]) / 2).reshape(-1, 3), + ((corner_coords[:, :-1] + corner_coords[:, 1:]) / 2).reshape(-1, 3), + ((corner_coords[:, :, :-1] + corner_coords[:, :, 1:]) / 2).reshape(-1, 3), + ], + dim=0, + ) + + # Create a flat array of [X, Y, Z] indices for each cube. + cube_indices = torch.zeros( + grid_size[0] - 1, grid_size[1] - 1, grid_size[2] - 1, 3, device=dev, dtype=torch.long + ) + cube_indices[range(grid_size[0] - 1), :, :, 0] = torch.arange(grid_size[0] - 1, device=dev)[:, None, None] + cube_indices[:, range(grid_size[1] - 1), :, 1] = torch.arange(grid_size[1] - 1, device=dev)[:, None] + cube_indices[:, :, range(grid_size[2] - 1), 2] = torch.arange(grid_size[2] - 1, device=dev) + flat_cube_indices = cube_indices.reshape(-1, 3) + + # Create a flat array mapping each cube to 12 global edge indices. + edge_indices = _create_flat_edge_indices(flat_cube_indices, grid_size) + + # Apply the LUT to figure out the triangles. + flat_bitmasks = bitmasks.reshape(-1).long() # must cast to long for indexing to believe this not a mask + local_tris = cases[flat_bitmasks] + local_masks = masks[flat_bitmasks] + # Compute the global edge indices for the triangles. + global_tris = torch.gather(edge_indices, 1, local_tris.reshape(local_tris.shape[0], -1)).reshape( + local_tris.shape + ) + # Select the used triangles for each cube. + selected_tris = global_tris.reshape(-1, 3)[local_masks.reshape(-1)] + + # Now we have a bunch of indices into the full list of possible vertices, + # but we want to reduce this list to only the used vertices. + used_vertex_indices = torch.unique(selected_tris.view(-1)) + used_edge_midpoints = edge_midpoints[used_vertex_indices] + old_index_to_new_index = torch.zeros(len(edge_midpoints), device=dev, dtype=torch.long) + old_index_to_new_index[used_vertex_indices] = torch.arange( + len(used_vertex_indices), device=dev, dtype=torch.long + ) + + # Rewrite the triangles to use the new indices + faces = torch.gather(old_index_to_new_index, 0, selected_tris.view(-1)).reshape(selected_tris.shape) + + # Compute the actual interpolated coordinates corresponding to edge midpoints. + v1 = torch.floor(used_edge_midpoints).to(torch.long) + v2 = torch.ceil(used_edge_midpoints).to(torch.long) + s1 = field[v1[:, 0], v1[:, 1], v1[:, 2]] + s2 = field[v2[:, 0], v2[:, 1], v2[:, 2]] + p1 = (v1.float() / (grid_size_tensor - 1)) * size + min_point + p2 = (v2.float() / (grid_size_tensor - 1)) * size + min_point + # The signs of s1 and s2 should be different. We want to find + # t such that t*s2 + (1-t)*s1 = 0. + t = (s1 / (s1 - s2))[:, None] + verts = t * p2 + (1 - t) * p1 + + return MeshDecoderOutput(verts=verts, faces=faces, vertex_channels=None) + + +@dataclass +class MLPNeRFModelOutput(BaseOutput): + density: torch.Tensor + signed_distance: torch.Tensor + channels: torch.Tensor + ts: torch.Tensor + + +class MLPNeRSTFModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + d_hidden: int = 256, + n_output: int = 12, + n_hidden_layers: int = 6, + act_fn: str = "swish", + insert_direction_at: int = 4, + ): + super().__init__() + + # Instantiate the MLP + + # Find out the dimension of encoded position and direction + dummy = torch.eye(1, 3) + d_posenc_pos = encode_position(position=dummy).shape[-1] + d_posenc_dir = encode_direction(position=dummy).shape[-1] + + mlp_widths = [d_hidden] * n_hidden_layers + input_widths = [d_posenc_pos] + mlp_widths + output_widths = mlp_widths + [n_output] + + if insert_direction_at is not None: + input_widths[insert_direction_at] += d_posenc_dir + + self.mlp = nn.ModuleList([nn.Linear(d_in, d_out) for d_in, d_out in zip(input_widths, output_widths)]) + + if act_fn == "swish": + # self.activation = swish + # yiyi testing: + self.activation = lambda x: F.silu(x) + else: + raise ValueError(f"Unsupported activation function {act_fn}") + + self.sdf_activation = torch.tanh + self.density_activation = torch.nn.functional.relu + self.channel_activation = torch.sigmoid + + def map_indices_to_keys(self, output): + h_map = { + "sdf": (0, 1), + "density_coarse": (1, 2), + "density_fine": (2, 3), + "stf": (3, 6), + "nerf_coarse": (6, 9), + "nerf_fine": (9, 12), + } + + mapped_output = {k: output[..., start:end] for k, (start, end) in h_map.items()} + + return mapped_output + + def forward(self, *, position, direction, ts, nerf_level="coarse", rendering_mode="nerf"): + h = encode_position(position) + + h_preact = h + h_directionless = None + for i, layer in enumerate(self.mlp): + if i == self.config.insert_direction_at: # 4 in the config + h_directionless = h_preact + h_direction = encode_direction(position, direction=direction) + h = torch.cat([h, h_direction], dim=-1) + + h = layer(h) + + h_preact = h + + if i < len(self.mlp) - 1: + h = self.activation(h) + + h_final = h + if h_directionless is None: + h_directionless = h_preact + + activation = self.map_indices_to_keys(h_final) + + if nerf_level == "coarse": + h_density = activation["density_coarse"] + else: + h_density = activation["density_fine"] + + if rendering_mode == "nerf": + if nerf_level == "coarse": + h_channels = activation["nerf_coarse"] + else: + h_channels = activation["nerf_fine"] + + elif rendering_mode == "stf": + h_channels = activation["stf"] + + density = self.density_activation(h_density) + signed_distance = self.sdf_activation(activation["sdf"]) + channels = self.channel_activation(h_channels) + + # yiyi notes: I think signed_distance is not used + return MLPNeRFModelOutput(density=density, signed_distance=signed_distance, channels=channels, ts=ts) + + +class ChannelsProj(nn.Module): + def __init__( + self, + *, + vectors: int, + channels: int, + d_latent: int, + ): + super().__init__() + self.proj = nn.Linear(d_latent, vectors * channels) + self.norm = nn.LayerNorm(channels) + self.d_latent = d_latent + self.vectors = vectors + self.channels = channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_bvd = x + w_vcd = self.proj.weight.view(self.vectors, self.channels, self.d_latent) + b_vc = self.proj.bias.view(1, self.vectors, self.channels) + h = torch.einsum("bvd,vcd->bvc", x_bvd, w_vcd) + h = self.norm(h) + + h = h + b_vc + return h + + +class ShapEParamsProjModel(ModelMixin, ConfigMixin): + """ + project the latent representation of a 3D asset to obtain weights of a multi-layer perceptron (MLP). + + For more details, see the original paper: + """ + + @register_to_config + def __init__( + self, + *, + param_names: tuple[str] = ( + "nerstf.mlp.0.weight", + "nerstf.mlp.1.weight", + "nerstf.mlp.2.weight", + "nerstf.mlp.3.weight", + ), + param_shapes: tuple[tuple[int]] = ( + (256, 93), + (256, 256), + (256, 256), + (256, 256), + ), + d_latent: int = 1024, + ): + super().__init__() + + # check inputs + if len(param_names) != len(param_shapes): + raise ValueError("Must provide same number of `param_names` as `param_shapes`") + self.projections = nn.ModuleDict({}) + for k, (vectors, channels) in zip(param_names, param_shapes): + self.projections[_sanitize_name(k)] = ChannelsProj( + vectors=vectors, + channels=channels, + d_latent=d_latent, + ) + + def forward(self, x: torch.Tensor): + out = {} + start = 0 + for k, shape in zip(self.config.param_names, self.config.param_shapes): + vectors, _ = shape + end = start + vectors + x_bvd = x[:, start:end] + out[k] = self.projections[_sanitize_name(k)](x_bvd).reshape(len(x), *shape) + start = end + return out + + +class ShapERenderer(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + *, + param_names: tuple[str] = ( + "nerstf.mlp.0.weight", + "nerstf.mlp.1.weight", + "nerstf.mlp.2.weight", + "nerstf.mlp.3.weight", + ), + param_shapes: tuple[tuple[int]] = ( + (256, 93), + (256, 256), + (256, 256), + (256, 256), + ), + d_latent: int = 1024, + d_hidden: int = 256, + n_output: int = 12, + n_hidden_layers: int = 6, + act_fn: str = "swish", + insert_direction_at: int = 4, + background: tuple[float] = ( + 255.0, + 255.0, + 255.0, + ), + ): + super().__init__() + + self.params_proj = ShapEParamsProjModel( + param_names=param_names, + param_shapes=param_shapes, + d_latent=d_latent, + ) + self.mlp = MLPNeRSTFModel(d_hidden, n_output, n_hidden_layers, act_fn, insert_direction_at) + self.void = VoidNeRFModel(background=background, channel_scale=255.0) + self.volume = BoundingBoxVolume(bbox_max=[1.0, 1.0, 1.0], bbox_min=[-1.0, -1.0, -1.0]) + self.mesh_decoder = MeshDecoder() + + @torch.no_grad() + def render_rays(self, rays, sampler, n_samples, prev_model_out=None, render_with_direction=False): + """ + Perform volumetric rendering over a partition of possible t's in the union of rendering volumes (written below + with some abuse of notations) + + C(r) := sum( + transmittance(t[i]) * integrate( + lambda t: density(t) * channels(t) * transmittance(t), [t[i], t[i + 1]], + ) for i in range(len(parts)) + ) + transmittance(t[-1]) * void_model(t[-1]).channels + + where + + 1) transmittance(s) := exp(-integrate(density, [t[0], s])) calculates the probability of light passing through + the volume specified by [t[0], s]. (transmittance of 1 means light can pass freely) 2) density and channels are + obtained by evaluating the appropriate part.model at time t. 3) [t[i], t[i + 1]] is defined as the range of t + where the ray intersects (parts[i].volume \\ union(part.volume for part in parts[:i])) at the surface of the + shell (if bounded). If the ray does not intersect, the integral over this segment is evaluated as 0 and + transmittance(t[i + 1]) := transmittance(t[i]). 4) The last term is integration to infinity (e.g. [t[-1], + math.inf]) that is evaluated by the void_model (i.e. we consider this space to be empty). + + Args: + rays: [batch_size x ... x 2 x 3] origin and direction. sampler: disjoint volume integrals. n_samples: + number of ts to sample. prev_model_outputs: model outputs from the previous rendering step, including + + :return: A tuple of + - `channels` + - A importance samplers for additional fine-grained rendering + - raw model output + """ + origin, direction = rays[..., 0, :], rays[..., 1, :] + + # Integrate over [t[i], t[i + 1]] + + # 1 Intersect the rays with the current volume and sample ts to integrate along. + vrange = self.volume.intersect(origin, direction, t0_lower=None) + ts = sampler.sample(vrange.t0, vrange.t1, n_samples) + ts = ts.to(rays.dtype) + + if prev_model_out is not None: + # Append the previous ts now before fprop because previous + # rendering used a different model and we can't reuse the output. + ts = torch.sort(torch.cat([ts, prev_model_out.ts], dim=-2), dim=-2).values + + batch_size, *_shape, _t0_dim = vrange.t0.shape + _, *ts_shape, _ts_dim = ts.shape + + # 2. Get the points along the ray and query the model + directions = torch.broadcast_to(direction.unsqueeze(-2), [batch_size, *ts_shape, 3]) + positions = origin.unsqueeze(-2) + ts * directions + + directions = directions.to(self.mlp.dtype) + positions = positions.to(self.mlp.dtype) + + optional_directions = directions if render_with_direction else None + + model_out = self.mlp( + position=positions, + direction=optional_directions, + ts=ts, + nerf_level="coarse" if prev_model_out is None else "fine", + ) + + # 3. Integrate the model results + channels, weights, transmittance = integrate_samples( + vrange, model_out.ts, model_out.density, model_out.channels + ) + + # 4. Clean up results that do not intersect with the volume. + transmittance = torch.where(vrange.intersected, transmittance, torch.ones_like(transmittance)) + channels = torch.where(vrange.intersected, channels, torch.zeros_like(channels)) + # 5. integration to infinity (e.g. [t[-1], math.inf]) that is evaluated by the void_model (i.e. we consider this space to be empty). + channels = channels + transmittance * self.void(origin) + + weighted_sampler = ImportanceRaySampler(vrange, ts=model_out.ts, weights=weights) + + return channels, weighted_sampler, model_out + + @torch.no_grad() + def decode_to_image( + self, + latents, + device, + size: int = 64, + ray_batch_size: int = 4096, + n_coarse_samples=64, + n_fine_samples=128, + ): + # project the parameters from the generated latents + projected_params = self.params_proj(latents) + + # update the mlp layers of the renderer + for name, param in self.mlp.state_dict().items(): + if f"nerstf.{name}" in projected_params.keys(): + param.copy_(projected_params[f"nerstf.{name}"].squeeze(0)) + + # create cameras object + camera = create_pan_cameras(size) + rays = camera.camera_rays + rays = rays.to(device) + n_batches = rays.shape[1] // ray_batch_size + + coarse_sampler = StratifiedRaySampler() + + images = [] + + for idx in range(n_batches): + rays_batch = rays[:, idx * ray_batch_size : (idx + 1) * ray_batch_size] + + # render rays with coarse, stratified samples. + _, fine_sampler, coarse_model_out = self.render_rays(rays_batch, coarse_sampler, n_coarse_samples) + # Then, render with additional importance-weighted ray samples. + channels, _, _ = self.render_rays( + rays_batch, fine_sampler, n_fine_samples, prev_model_out=coarse_model_out + ) + + images.append(channels) + + images = torch.cat(images, dim=1) + images = images.view(*camera.shape, camera.height, camera.width, -1).squeeze(0) + + return images + + @torch.no_grad() + def decode_to_mesh( + self, + latents, + device, + grid_size: int = 128, + query_batch_size: int = 4096, + texture_channels: tuple = ("R", "G", "B"), + ): + # 1. project the parameters from the generated latents + projected_params = self.params_proj(latents) + + # 2. update the mlp layers of the renderer + for name, param in self.mlp.state_dict().items(): + if f"nerstf.{name}" in projected_params.keys(): + param.copy_(projected_params[f"nerstf.{name}"].squeeze(0)) + + # 3. decoding with STF rendering + # 3.1 query the SDF values at vertices along a regular 128**3 grid + + query_points = volume_query_points(self.volume, grid_size) + query_positions = query_points[None].repeat(1, 1, 1).to(device=device, dtype=self.mlp.dtype) + + fields = [] + + for idx in range(0, query_positions.shape[1], query_batch_size): + query_batch = query_positions[:, idx : idx + query_batch_size] + + model_out = self.mlp( + position=query_batch, direction=None, ts=None, nerf_level="fine", rendering_mode="stf" + ) + fields.append(model_out.signed_distance) + + # predicted SDF values + fields = torch.cat(fields, dim=1) + fields = fields.float() + + assert len(fields.shape) == 3 and fields.shape[-1] == 1, ( + f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}" + ) + + fields = fields.reshape(1, *([grid_size] * 3)) + + # create grid 128 x 128 x 128 + # - force a negative border around the SDFs to close off all the models. + full_grid = torch.zeros( + 1, + grid_size + 2, + grid_size + 2, + grid_size + 2, + device=fields.device, + dtype=fields.dtype, + ) + full_grid.fill_(-1.0) + full_grid[:, 1:-1, 1:-1, 1:-1] = fields + fields = full_grid + + # apply a differentiable implementation of Marching Cubes to construct meshs + raw_meshes = [] + mesh_mask = [] + + for field in fields: + raw_mesh = self.mesh_decoder(field, self.volume.bbox_min, self.volume.bbox_max - self.volume.bbox_min) + mesh_mask.append(True) + raw_meshes.append(raw_mesh) + + mesh_mask = torch.tensor(mesh_mask, device=fields.device) + max_vertices = max(len(m.verts) for m in raw_meshes) + + # 3.2. query the texture color head at each vertex of the resulting mesh. + texture_query_positions = torch.stack( + [m.verts[torch.arange(0, max_vertices) % len(m.verts)] for m in raw_meshes], + dim=0, + ) + texture_query_positions = texture_query_positions.to(device=device, dtype=self.mlp.dtype) + + textures = [] + + for idx in range(0, texture_query_positions.shape[1], query_batch_size): + query_batch = texture_query_positions[:, idx : idx + query_batch_size] + + texture_model_out = self.mlp( + position=query_batch, direction=None, ts=None, nerf_level="fine", rendering_mode="stf" + ) + textures.append(texture_model_out.channels) + + # predict texture color + textures = torch.cat(textures, dim=1) + + textures = _convert_srgb_to_linear(textures) + textures = textures.float() + + # 3.3 augment the mesh with texture data + assert len(textures.shape) == 3 and textures.shape[-1] == len(texture_channels), ( + f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}" + ) + + for m, texture in zip(raw_meshes, textures): + texture = texture[: len(m.verts)] + m.vertex_channels = dict(zip(texture_channels, texture.unbind(-1))) + + return raw_meshes[0] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/skyreels_v2/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/skyreels_v2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..84d2a2dd35007a7ee1213fd772ad86c96345bf7d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/skyreels_v2/__init__.py @@ -0,0 +1,59 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_skyreels_v2"] = ["SkyReelsV2Pipeline"] + _import_structure["pipeline_skyreels_v2_diffusion_forcing"] = ["SkyReelsV2DiffusionForcingPipeline"] + _import_structure["pipeline_skyreels_v2_diffusion_forcing_i2v"] = [ + "SkyReelsV2DiffusionForcingImageToVideoPipeline" + ] + _import_structure["pipeline_skyreels_v2_diffusion_forcing_v2v"] = [ + "SkyReelsV2DiffusionForcingVideoToVideoPipeline" + ] + _import_structure["pipeline_skyreels_v2_i2v"] = ["SkyReelsV2ImageToVideoPipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_skyreels_v2 import SkyReelsV2Pipeline + from .pipeline_skyreels_v2_diffusion_forcing import SkyReelsV2DiffusionForcingPipeline + from .pipeline_skyreels_v2_diffusion_forcing_i2v import SkyReelsV2DiffusionForcingImageToVideoPipeline + from .pipeline_skyreels_v2_diffusion_forcing_v2v import SkyReelsV2DiffusionForcingVideoToVideoPipeline + from .pipeline_skyreels_v2_i2v import SkyReelsV2ImageToVideoPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/skyreels_v2/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/skyreels_v2/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..dac2316362ec20c1874b514637ef1681fb7dd7d3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/skyreels_v2/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class SkyReelsV2PipelineOutput(BaseOutput): + r""" + Output class for SkyReelsV2 pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..c92608fad3b6b7ed2c23a2bf7655e66730bffb30 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py @@ -0,0 +1,610 @@ +# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +from typing import Any, Callable + +import regex as re +import torch +from transformers import AutoTokenizer, T5EncoderModel, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import SkyReelsV2LoraLoaderMixin +from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import SkyReelsV2PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """\ + Examples: + ```py + >>> import torch + >>> from diffusers import ( + ... SkyReelsV2Pipeline, + ... UniPCMultistepScheduler, + ... AutoencoderKLWan, + ... ) + >>> from diffusers.utils import export_to_video + + >>> # Load the pipeline + >>> # Available models: + >>> # - Skywork/SkyReels-V2-T2V-14B-540P-Diffusers + >>> # - Skywork/SkyReels-V2-T2V-14B-720P-Diffusers + >>> vae = AutoencoderKLWan.from_pretrained( + ... "Skywork/SkyReels-V2-T2V-14B-720P-Diffusers", + ... subfolder="vae", + ... torch_dtype=torch.float32, + ... ) + >>> pipe = SkyReelsV2Pipeline.from_pretrained( + ... "Skywork/SkyReels-V2-T2V-14B-720P-Diffusers", + ... vae=vae, + ... torch_dtype=torch.bfloat16, + ... ) + >>> flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift) + >>> pipe = pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + + >>> output = pipe( + ... prompt=prompt, + ... num_inference_steps=50, + ... height=544, + ... width=960, + ... guidance_scale=6.0, # 6.0 for T2V, 5.0 for I2V + ... num_frames=97, + ... ).frames[0] + >>> export_to_video(output, "video.mp4", fps=24, quality=8) + ``` +""" + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +class SkyReelsV2Pipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin): + r""" + Pipeline for Text-to-Video (t2v) generation using SkyReels-V2. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + transformer ([`SkyReelsV2Transformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: T5EncoderModel | UMT5EncoderModel, + transformer: SkyReelsV2Transformer3DModel, + vae: AutoencoderKLWan, + scheduler: UniPCMultistepScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int = 544, + width: int = 960, + num_frames: int = 97, + num_inference_steps: int = 50, + guidance_scale: float = 6.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, defaults to `544`): + The height in pixels of the generated image. + width (`int`, defaults to `960`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `97`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `6.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`SkyReelsV2PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to `512`): + The maximum sequence length for the text encoder. + + Examples: + + Returns: + [`~SkyReelsV2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return SkyReelsV2PipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py new file mode 100644 index 0000000000000000000000000000000000000000..8751240a1af9ef8413500092e392b5dae3f1fccc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -0,0 +1,978 @@ +# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import math +import re +from copy import deepcopy +from typing import Any, Callable + +import torch +from transformers import AutoTokenizer, T5EncoderModel, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import SkyReelsV2LoraLoaderMixin +from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import SkyReelsV2PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """\ + Examples: + ```py + >>> import torch + >>> from diffusers import ( + ... SkyReelsV2DiffusionForcingPipeline, + ... UniPCMultistepScheduler, + ... AutoencoderKLWan, + ... ) + >>> from diffusers.utils import export_to_video + + >>> # Load the pipeline + >>> # Available models: + >>> # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers + >>> # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers + >>> # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers + >>> vae = AutoencoderKLWan.from_pretrained( + ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers", + ... subfolder="vae", + ... torch_dtype=torch.float32, + ... ) + >>> pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained( + ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers", + ... vae=vae, + ... torch_dtype=torch.bfloat16, + ... ) + >>> flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift) + >>> pipe = pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + + >>> output = pipe( + ... prompt=prompt, + ... num_inference_steps=30, + ... height=544, + ... width=960, + ... guidance_scale=6.0, # 6.0 for T2V, 5.0 for I2V + ... num_frames=97, + ... ar_step=5, # Controls asynchronous inference (0 for synchronous mode) + ... causal_block_size=5, # Number of frames processed together in a causal block + ... overlap_history=None, # Number of frames to overlap for smooth transitions in long videos + ... addnoise_condition=20, # Improves consistency in long video generation + ... ).frames[0] + >>> export_to_video(output, "video.mp4", fps=24, quality=8) + ``` +""" + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin): + """ + Pipeline for Text-to-Video (t2v) generation using SkyReels-V2 with diffusion forcing. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a specific device, etc.). + + Args: + tokenizer ([`AutoTokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`UMT5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + transformer ([`SkyReelsV2Transformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: T5EncoderModel | UMT5EncoderModel, + transformer: SkyReelsV2Transformer3DModel, + vae: AutoencoderKLWan, + scheduler: UniPCMultistepScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + overlap_history=None, + num_frames=None, + base_num_frames=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if num_frames > base_num_frames and overlap_history is None: + raise ValueError( + "`overlap_history` is required when `num_frames` exceeds `base_num_frames` to ensure smooth transitions in long video generation. " + "Please specify a value for `overlap_history`. Recommended values are 17 or 37." + ) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 97, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + base_latent_num_frames: int | None = None, + video_latents: torch.Tensor | None = None, + causal_block_size: int | None = None, + overlap_history_latent_frames: int | None = None, + long_video_iter: int | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + prefix_video_latents = None + prefix_video_latents_frames = 0 + + if video_latents is not None: # long video generation at the iterations other than the first one + prefix_video_latents = video_latents[:, :, -overlap_history_latent_frames:] + + if prefix_video_latents.shape[2] % causal_block_size != 0: + truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size + logger.warning( + f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment. " + f"This truncation ensures compatibility with the causal block size, which is required for proper processing. " + f"However, it may slightly affect the continuity of the generated video at the truncation boundary." + ) + prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents] + prefix_video_latents_frames = prefix_video_latents.shape[2] + + finished_frame_num = ( + long_video_iter * (base_latent_num_frames - overlap_history_latent_frames) + + overlap_history_latent_frames + ) + left_frame_num = num_latent_frames - finished_frame_num + num_latent_frames = min(left_frame_num + overlap_history_latent_frames, base_latent_num_frames) + elif base_latent_num_frames is not None: # long video generation at the first iteration + num_latent_frames = base_latent_num_frames + else: # short video generation + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + latent_height, + latent_width, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents, num_latent_frames, prefix_video_latents, prefix_video_latents_frames + + def generate_timestep_matrix( + self, + num_latent_frames: int, + step_template: torch.Tensor, + base_num_latent_frames: int, + ar_step: int = 5, + num_pre_ready: int = 0, + causal_block_size: int = 1, + shrink_interval_with_mask: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]: + """ + This function implements the core diffusion forcing algorithm that creates a coordinated denoising schedule + across temporal frames. It supports both synchronous and asynchronous generation modes: + + **Synchronous Mode** (ar_step=0, causal_block_size=1): + - All frames are denoised simultaneously at each timestep + - Each frame follows the same denoising trajectory: [1000, 800, 600, ..., 0] + - Simpler but may have less temporal consistency for long videos + + **Asynchronous Mode** (ar_step>0, causal_block_size>1): + - Frames are grouped into causal blocks and processed block/chunk-wise + - Each block is denoised in a staggered pattern creating a "denoising wave" + - Earlier blocks are more denoised, later blocks lag behind by ar_step timesteps + - Creates stronger temporal dependencies and better consistency + + Args: + num_latent_frames (int): Total number of latent frames to generate + step_template (torch.Tensor): Base timestep schedule (e.g., [1000, 800, 600, ..., 0]) + base_num_latent_frames (int): Maximum frames the model can process in one forward pass + ar_step (int, optional): Autoregressive step size for temporal lag. + 0 = synchronous, >0 = asynchronous. Defaults to 5. + num_pre_ready (int, optional): + Number of frames already denoised (e.g., from prefix in a video2video task). + Defaults to 0. + causal_block_size (int, optional): Number of frames processed as a causal block. + Defaults to 1. + shrink_interval_with_mask (bool, optional): Whether to optimize processing intervals. + Defaults to False. + + Returns: + tuple containing: + - step_matrix (torch.Tensor): Matrix of timesteps for each frame at each iteration Shape: + [num_iterations, num_latent_frames] + - step_index (torch.Tensor): Index matrix for timestep lookup Shape: [num_iterations, + num_latent_frames] + - step_update_mask (torch.Tensor): Boolean mask indicating which frames to update Shape: + [num_iterations, num_latent_frames] + - valid_interval (list[tuple]): list of (start, end) intervals for each iteration + + Raises: + ValueError: If ar_step is too small for the given configuration + """ + # Initialize lists to store the scheduling matrices and metadata + step_matrix, step_index = [], [] # Will store timestep values and indices for each iteration + update_mask, valid_interval = [], [] # Will store update masks and processing intervals + + # Calculate total number of denoising iterations (add 1 for initial noise state) + num_iterations = len(step_template) + 1 + + # Convert frame counts to block counts for causal processing + # Each block contains causal_block_size frames that are processed together + # E.g.: 25 frames ÷ 5 = 5 blocks total + num_blocks = num_latent_frames // causal_block_size + base_num_blocks = base_num_latent_frames // causal_block_size + + # Validate ar_step is sufficient for the given configuration + # In asynchronous mode, we need enough timesteps to create the staggered pattern + if base_num_blocks < num_blocks: + min_ar_step = len(step_template) / base_num_blocks + if ar_step < min_ar_step: + raise ValueError(f"`ar_step` should be at least {math.ceil(min_ar_step)} in your setting") + + # Extend step_template with boundary values for easier indexing + # 999: dummy value for counter starting from 1 + # 0: final timestep (completely denoised) + step_template = torch.cat( + [ + torch.tensor([999], dtype=torch.int64, device=step_template.device), + step_template.long(), + torch.tensor([0], dtype=torch.int64, device=step_template.device), + ] + ) + + # Initialize the previous row state (tracks denoising progress for each block) + # 0 means not started, num_iterations means fully denoised + pre_row = torch.zeros(num_blocks, dtype=torch.long) + + # Mark pre-ready frames (e.g., from prefix video for a video2video task) as already at final denoising state + if num_pre_ready > 0: + pre_row[: num_pre_ready // causal_block_size] = num_iterations + + # Main loop: Generate denoising schedule until all frames are fully denoised + while not torch.all(pre_row >= (num_iterations - 1)): + # Create new row representing the next denoising step + new_row = torch.zeros(num_blocks, dtype=torch.long) + + # Apply diffusion forcing logic for each block + for i in range(num_blocks): + if i == 0 or pre_row[i - 1] >= ( + num_iterations - 1 + ): # the first frame or the last frame is completely denoised + new_row[i] = pre_row[i] + 1 + else: + # Asynchronous mode: lag behind previous block by ar_step timesteps + # This creates the "diffusion forcing" staggered pattern + new_row[i] = new_row[i - 1] - ar_step + + # Clamp values to valid range [0, num_iterations] + new_row = new_row.clamp(0, num_iterations) + + # Create update mask: True for blocks that need denoising update at this iteration + # Exclude blocks that haven't started (new_row != pre_row) or are finished (new_row != num_iterations) + # Final state example: [False, ..., False, True, True, True, True, True] + # where first 20 frames are done (False) and last 5 frames still need updates (True) + update_mask.append((new_row != pre_row) & (new_row != num_iterations)) + + # Store the iteration state + step_index.append(new_row) # Index into step_template + step_matrix.append(step_template[new_row]) # Actual timestep values + pre_row = new_row # Update for next iteration + + # For videos longer than model capacity, we process in sliding windows + terminal_flag = base_num_blocks + + # Optional optimization: shrink interval based on first update mask + if shrink_interval_with_mask: + idx_sequence = torch.arange(num_blocks, dtype=torch.int64) + update_mask = update_mask[0] + update_mask_idx = idx_sequence[update_mask] + last_update_idx = update_mask_idx[-1].item() + terminal_flag = last_update_idx + 1 + + # Each interval defines which frames to process in the current forward pass + for curr_mask in update_mask: + # Extend terminal flag if current mask has updates beyond current terminal + if terminal_flag < num_blocks and curr_mask[terminal_flag]: + terminal_flag += 1 + # Create interval: [start, end) where start ensures we don't exceed model capacity + valid_interval.append((max(terminal_flag - base_num_blocks, 0), terminal_flag)) + + # Convert lists to tensors for efficient processing + step_update_mask = torch.stack(update_mask, dim=0) + step_index = torch.stack(step_index, dim=0) + step_matrix = torch.stack(step_matrix, dim=0) + + # Each block's schedule is replicated to all frames within that block + if causal_block_size > 1: + # Expand each block to causal_block_size frames + step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous() + step_index = step_index.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous() + step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous() + # Scale intervals from block-level to frame-level + valid_interval = [(s * causal_block_size, e * causal_block_size) for s, e in valid_interval] + + return step_matrix, step_index, step_update_mask, valid_interval + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str], + negative_prompt: str | list[str] = None, + height: int = 544, + width: int = 960, + num_frames: int = 97, + num_inference_steps: int = 50, + guidance_scale: float = 6.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + overlap_history: int | None = None, + addnoise_condition: float = 0, + base_num_frames: int = 97, + ar_step: int = 0, + causal_block_size: int | None = None, + fps: int = 24, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, defaults to `544`): + The height of the generated video. + width (`int`, defaults to `960`): + The width of the generated video. + num_frames (`int`, defaults to `97`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `6.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. (**6.0 for T2V**, **5.0 for I2V**) + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`SkyReelsV2PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to `512`): + The maximum sequence length of the prompt. + overlap_history (`int`, *optional*, defaults to `None`): + Number of frames to overlap for smooth transitions in long videos. If `None`, the pipeline assumes + short video generation mode, and no overlap is applied. 17 and 37 are recommended to set. + addnoise_condition (`float`, *optional*, defaults to `0`): + This is used to help smooth the long video generation by adding some noise to the clean condition. Too + large noise can cause the inconsistency as well. 20 is a recommended value, and you may try larger + ones, but it is recommended to not exceed 50. + base_num_frames (`int`, *optional*, defaults to `97`): + 97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**) + ar_step (`int`, *optional*, defaults to `0`): + Controls asynchronous inference (0 for synchronous mode) You can set `ar_step=5` to enable asynchronous + inference. When asynchronous inference, `causal_block_size=5` is recommended while it is not supposed + to be set for synchronous generation. Asynchronous inference will take more steps to diffuse the whole + sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous + inference may improve the instruction following and visual consistent performance. + causal_block_size (`int`, *optional*, defaults to `None`): + The number of frames in each block/chunk. Recommended when using asynchronous inference (when ar_step > + 0) + fps (`int`, *optional*, defaults to `24`): + Frame rate of the generated video + + Examples: + + Returns: + [`~SkyReelsV2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + overlap_history, + num_frames, + base_num_frames, + ) + + if addnoise_condition > 60: + logger.warning( + f"The value of 'addnoise_condition' is too large ({addnoise_condition}) and may cause inconsistencies in long video generation. A value of 20 is recommended." + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + if causal_block_size is None: + causal_block_size = self.transformer.config.num_frame_per_block + else: + self.transformer._set_ar_attention(causal_block_size) + + fps_embeds = [fps] * prompt_embeds.shape[0] + fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] + + # Determine if we're doing long video generation + is_long_video = overlap_history is not None and base_num_frames is not None and num_frames > base_num_frames + # Initialize accumulated_latents to store all latents in one tensor + accumulated_latents = None + if is_long_video: + # Long video generation setup + overlap_history_latent_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1 + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + base_latent_num_frames = ( + (base_num_frames - 1) // self.vae_scale_factor_temporal + 1 + if base_num_frames is not None + else num_latent_frames + ) + n_iter = ( + 1 + + (num_latent_frames - base_latent_num_frames - 1) + // (base_latent_num_frames - overlap_history_latent_frames) + + 1 + ) + else: + # Short video generation setup + n_iter = 1 + base_latent_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + + # Loop through iterations (multiple iterations only for long videos) + for iter_idx in range(n_iter): + if is_long_video: + logger.debug(f"Processing iteration {iter_idx + 1}/{n_iter} for long video generation...") + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents, current_num_latent_frames, prefix_video_latents, prefix_video_latents_frames = ( + self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents if iter_idx == 0 else None, + video_latents=accumulated_latents, # Pass latents directly instead of decoded video + base_latent_num_frames=base_latent_num_frames if is_long_video else None, + causal_block_size=causal_block_size, + overlap_history_latent_frames=overlap_history_latent_frames if is_long_video else None, + long_video_iter=iter_idx if is_long_video else None, + ) + ) + + if prefix_video_latents_frames > 0: + latents[:, :, :prefix_video_latents_frames, :, :] = prefix_video_latents.to(transformer_dtype) + + # 6. Prepare sample schedulers and timestep matrix + sample_schedulers = [] + for _ in range(current_num_latent_frames): + sample_scheduler = deepcopy(self.scheduler) + sample_scheduler.set_timesteps(num_inference_steps, device=device) + sample_schedulers.append(sample_scheduler) + + # Different matrix generation for short vs long video + step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( + current_num_latent_frames, + timesteps, + current_num_latent_frames if is_long_video else base_latent_num_frames, + ar_step, + prefix_video_latents_frames, + causal_block_size, + ) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(step_matrix) + + with self.progress_bar(total=len(step_matrix)) as progress_bar: + for i, t in enumerate(step_matrix): + if self.interrupt: + continue + + self._current_timestep = t + valid_interval_start, valid_interval_end = valid_interval[i] + latent_model_input = ( + latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + ) + timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone() + + if addnoise_condition > 0 and valid_interval_start < prefix_video_latents_frames: + noise_factor = 0.001 * addnoise_condition + latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :] = ( + latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :] + * (1.0 - noise_factor) + + torch.randn_like( + latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :] + ) + * noise_factor + ) + timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + enable_diffusion_forcing=True, + fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + enable_diffusion_forcing=True, + fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + update_mask_i = step_update_mask[i] + for idx in range(valid_interval_start, valid_interval_end): + if update_mask_i[idx].item(): + latents[:, :, idx, :, :] = sample_schedulers[idx].step( + noise_pred[:, :, idx - valid_interval_start, :, :], + t[idx], + latents[:, :, idx, :, :], + return_dict=False, + )[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(step_matrix) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # Handle latent accumulation for long videos or use the current latents for short videos + if is_long_video: + if accumulated_latents is None: + accumulated_latents = latents + else: + # Keep overlap frames for conditioning but don't include them in final output + accumulated_latents = torch.cat( + [accumulated_latents, latents[:, :, overlap_history_latent_frames:]], dim=2 + ) + + if is_long_video: + latents = accumulated_latents + + self._current_timestep = None + + # Final decoding step - convert latents to pixels + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return SkyReelsV2PipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py new file mode 100644 index 0000000000000000000000000000000000000000..a8f1b3a84a4aa770dd07b30fc396db4b7adac5c3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -0,0 +1,1059 @@ +# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import math +import re +from copy import deepcopy +from typing import Any, Callable + +import PIL +import torch +from transformers import AutoTokenizer, T5EncoderModel, UMT5EncoderModel + +from diffusers.image_processor import PipelineImageInput +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import SkyReelsV2LoraLoaderMixin +from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import SkyReelsV2PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """\ + Examples: + ```py + >>> import torch + >>> from diffusers import ( + ... SkyReelsV2DiffusionForcingImageToVideoPipeline, + ... UniPCMultistepScheduler, + ... AutoencoderKLWan, + ... ) + >>> from diffusers.utils import export_to_video + >>> from PIL import Image + + >>> # Load the pipeline + >>> # Available models: + >>> # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers + >>> # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers + >>> # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers + >>> vae = AutoencoderKLWan.from_pretrained( + ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers", + ... subfolder="vae", + ... torch_dtype=torch.float32, + ... ) + >>> pipe = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained( + ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers", + ... vae=vae, + ... torch_dtype=torch.bfloat16, + ... ) + >>> flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift) + >>> pipe = pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + >>> image = Image.open("path/to/image.png") + + >>> output = pipe( + ... image=image, + ... prompt=prompt, + ... num_inference_steps=50, + ... height=544, + ... width=960, + ... guidance_scale=5.0, # 6.0 for T2V, 5.0 for I2V + ... num_frames=97, + ... ar_step=0, # Controls asynchronous inference (0 for synchronous mode) + ... overlap_history=None, # Number of frames to overlap for smooth transitions in long videos + ... addnoise_condition=20, # Improves consistency in long video generation + ... ).frames[0] + >>> export_to_video(output, "video.mp4", fps=24, quality=8) + ``` +""" + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class SkyReelsV2DiffusionForcingImageToVideoPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin): + """ + Pipeline for Image-to-Video (i2v) generation using SkyReels-V2 with diffusion forcing. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a specific device, etc.). + + Args: + tokenizer ([`AutoTokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`UMT5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + transformer ([`SkyReelsV2Transformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: T5EncoderModel | UMT5EncoderModel, + transformer: SkyReelsV2Transformer3DModel, + vae: AutoencoderKLWan, + scheduler: UniPCMultistepScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + image_embeds=None, + callback_on_step_end_tensor_inputs=None, + overlap_history=None, + num_frames=None, + base_num_frames=None, + ): + if image is not None and image_embeds is not None: + raise ValueError( + f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to" + " only forward one of the two." + ) + if image is None and image_embeds is None: + raise ValueError( + "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined." + ) + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if num_frames > base_num_frames and overlap_history is None: + raise ValueError( + "`overlap_history` is required when `num_frames` exceeds `base_num_frames` to ensure smooth transitions in long video generation. " + "Please specify a value for `overlap_history`. Recommended values are 17 or 37." + ) + + def prepare_latents( + self, + image: PipelineImageInput | None, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 97, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + last_image: torch.Tensor | None = None, + video_latents: torch.Tensor | None = None, + base_latent_num_frames: int | None = None, + causal_block_size: int | None = None, + overlap_history_latent_frames: int | None = None, + long_video_iter: int | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + prefix_video_latents_frames = 0 + + if video_latents is not None: # long video generation at the iterations other than the first one + condition = video_latents[:, :, -overlap_history_latent_frames:] + + if condition.shape[2] % causal_block_size != 0: + truncate_len_latents = condition.shape[2] % causal_block_size + logger.warning( + f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment. " + f"This truncation ensures compatibility with the causal block size, which is required for proper processing. " + f"However, it may slightly affect the continuity of the generated video at the truncation boundary." + ) + condition = condition[:, :, :-truncate_len_latents] + prefix_video_latents_frames = condition.shape[2] + + finished_frame_num = ( + long_video_iter * (base_latent_num_frames - overlap_history_latent_frames) + + overlap_history_latent_frames + ) + left_frame_num = num_latent_frames - finished_frame_num + num_latent_frames = min(left_frame_num + overlap_history_latent_frames, base_latent_num_frames) + elif base_latent_num_frames is not None: # long video generation at the first iteration + num_latent_frames = base_latent_num_frames + else: # short video generation + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + if image is not None: + image = image.unsqueeze(2) + if last_image is not None: + last_image = last_image.unsqueeze(2) + video_condition = torch.cat([image, last_image], dim=0) + else: + video_condition = image + + video_condition = video_condition.to(device=device, dtype=self.vae.dtype) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + + if isinstance(generator, list): + latent_condition = [ + retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator + ] + latent_condition = torch.cat(latent_condition) + else: + latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") + latent_condition = latent_condition.repeat_interleave(batch_size, dim=0) + + latent_condition = latent_condition.to(dtype) + condition = (latent_condition - latents_mean) * latents_std + prefix_video_latents_frames = condition.shape[2] + + return latents, num_latent_frames, condition, prefix_video_latents_frames + + # Copied from diffusers.pipelines.skyreels_v2.pipeline_skyreels_v2_diffusion_forcing.SkyReelsV2DiffusionForcingPipeline.generate_timestep_matrix + def generate_timestep_matrix( + self, + num_latent_frames: int, + step_template: torch.Tensor, + base_num_latent_frames: int, + ar_step: int = 5, + num_pre_ready: int = 0, + causal_block_size: int = 1, + shrink_interval_with_mask: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]: + """ + This function implements the core diffusion forcing algorithm that creates a coordinated denoising schedule + across temporal frames. It supports both synchronous and asynchronous generation modes: + + **Synchronous Mode** (ar_step=0, causal_block_size=1): + - All frames are denoised simultaneously at each timestep + - Each frame follows the same denoising trajectory: [1000, 800, 600, ..., 0] + - Simpler but may have less temporal consistency for long videos + + **Asynchronous Mode** (ar_step>0, causal_block_size>1): + - Frames are grouped into causal blocks and processed block/chunk-wise + - Each block is denoised in a staggered pattern creating a "denoising wave" + - Earlier blocks are more denoised, later blocks lag behind by ar_step timesteps + - Creates stronger temporal dependencies and better consistency + + Args: + num_latent_frames (int): Total number of latent frames to generate + step_template (torch.Tensor): Base timestep schedule (e.g., [1000, 800, 600, ..., 0]) + base_num_latent_frames (int): Maximum frames the model can process in one forward pass + ar_step (int, optional): Autoregressive step size for temporal lag. + 0 = synchronous, >0 = asynchronous. Defaults to 5. + num_pre_ready (int, optional): + Number of frames already denoised (e.g., from prefix in a video2video task). + Defaults to 0. + causal_block_size (int, optional): Number of frames processed as a causal block. + Defaults to 1. + shrink_interval_with_mask (bool, optional): Whether to optimize processing intervals. + Defaults to False. + + Returns: + tuple containing: + - step_matrix (torch.Tensor): Matrix of timesteps for each frame at each iteration Shape: + [num_iterations, num_latent_frames] + - step_index (torch.Tensor): Index matrix for timestep lookup Shape: [num_iterations, + num_latent_frames] + - step_update_mask (torch.Tensor): Boolean mask indicating which frames to update Shape: + [num_iterations, num_latent_frames] + - valid_interval (list[tuple]): list of (start, end) intervals for each iteration + + Raises: + ValueError: If ar_step is too small for the given configuration + """ + # Initialize lists to store the scheduling matrices and metadata + step_matrix, step_index = [], [] # Will store timestep values and indices for each iteration + update_mask, valid_interval = [], [] # Will store update masks and processing intervals + + # Calculate total number of denoising iterations (add 1 for initial noise state) + num_iterations = len(step_template) + 1 + + # Convert frame counts to block counts for causal processing + # Each block contains causal_block_size frames that are processed together + # E.g.: 25 frames ÷ 5 = 5 blocks total + num_blocks = num_latent_frames // causal_block_size + base_num_blocks = base_num_latent_frames // causal_block_size + + # Validate ar_step is sufficient for the given configuration + # In asynchronous mode, we need enough timesteps to create the staggered pattern + if base_num_blocks < num_blocks: + min_ar_step = len(step_template) / base_num_blocks + if ar_step < min_ar_step: + raise ValueError(f"`ar_step` should be at least {math.ceil(min_ar_step)} in your setting") + + # Extend step_template with boundary values for easier indexing + # 999: dummy value for counter starting from 1 + # 0: final timestep (completely denoised) + step_template = torch.cat( + [ + torch.tensor([999], dtype=torch.int64, device=step_template.device), + step_template.long(), + torch.tensor([0], dtype=torch.int64, device=step_template.device), + ] + ) + + # Initialize the previous row state (tracks denoising progress for each block) + # 0 means not started, num_iterations means fully denoised + pre_row = torch.zeros(num_blocks, dtype=torch.long) + + # Mark pre-ready frames (e.g., from prefix video for a video2video task) as already at final denoising state + if num_pre_ready > 0: + pre_row[: num_pre_ready // causal_block_size] = num_iterations + + # Main loop: Generate denoising schedule until all frames are fully denoised + while not torch.all(pre_row >= (num_iterations - 1)): + # Create new row representing the next denoising step + new_row = torch.zeros(num_blocks, dtype=torch.long) + + # Apply diffusion forcing logic for each block + for i in range(num_blocks): + if i == 0 or pre_row[i - 1] >= ( + num_iterations - 1 + ): # the first frame or the last frame is completely denoised + new_row[i] = pre_row[i] + 1 + else: + # Asynchronous mode: lag behind previous block by ar_step timesteps + # This creates the "diffusion forcing" staggered pattern + new_row[i] = new_row[i - 1] - ar_step + + # Clamp values to valid range [0, num_iterations] + new_row = new_row.clamp(0, num_iterations) + + # Create update mask: True for blocks that need denoising update at this iteration + # Exclude blocks that haven't started (new_row != pre_row) or are finished (new_row != num_iterations) + # Final state example: [False, ..., False, True, True, True, True, True] + # where first 20 frames are done (False) and last 5 frames still need updates (True) + update_mask.append((new_row != pre_row) & (new_row != num_iterations)) + + # Store the iteration state + step_index.append(new_row) # Index into step_template + step_matrix.append(step_template[new_row]) # Actual timestep values + pre_row = new_row # Update for next iteration + + # For videos longer than model capacity, we process in sliding windows + terminal_flag = base_num_blocks + + # Optional optimization: shrink interval based on first update mask + if shrink_interval_with_mask: + idx_sequence = torch.arange(num_blocks, dtype=torch.int64) + update_mask = update_mask[0] + update_mask_idx = idx_sequence[update_mask] + last_update_idx = update_mask_idx[-1].item() + terminal_flag = last_update_idx + 1 + + # Each interval defines which frames to process in the current forward pass + for curr_mask in update_mask: + # Extend terminal flag if current mask has updates beyond current terminal + if terminal_flag < num_blocks and curr_mask[terminal_flag]: + terminal_flag += 1 + # Create interval: [start, end) where start ensures we don't exceed model capacity + valid_interval.append((max(terminal_flag - base_num_blocks, 0), terminal_flag)) + + # Convert lists to tensors for efficient processing + step_update_mask = torch.stack(update_mask, dim=0) + step_index = torch.stack(step_index, dim=0) + step_matrix = torch.stack(step_matrix, dim=0) + + # Each block's schedule is replicated to all frames within that block + if causal_block_size > 1: + # Expand each block to causal_block_size frames + step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous() + step_index = step_index.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous() + step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous() + # Scale intervals from block-level to frame-level + valid_interval = [(s * causal_block_size, e * causal_block_size) for s, e in valid_interval] + + return step_matrix, step_index, step_update_mask, valid_interval + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int = 544, + width: int = 960, + num_frames: int = 97, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + image_embeds: torch.Tensor | None = None, + last_image: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + overlap_history: int | None = None, + addnoise_condition: float = 0, + base_num_frames: int = 97, + ar_step: int = 0, + causal_block_size: int | None = None, + fps: int = 24, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, defaults to `544`): + The height of the generated video. + width (`int`, defaults to `960`): + The width of the generated video. + num_frames (`int`, defaults to `97`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. (**6.0 for T2V**, **5.0 for I2V**) + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `negative_prompt` input argument. + image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided, + image embeddings are generated from the `image` input argument. + last_image (`torch.Tensor`, *optional*): + Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided, + image embeddings are generated from the `image` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`SkyReelsV2PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to `512`): + The maximum sequence length of the prompt. + overlap_history (`int`, *optional*, defaults to `None`): + Number of frames to overlap for smooth transitions in long videos. If `None`, the pipeline assumes + short video generation mode, and no overlap is applied. 17 and 37 are recommended to set. + addnoise_condition (`float`, *optional*, defaults to `0`): + This is used to help smooth the long video generation by adding some noise to the clean condition. Too + large noise can cause the inconsistency as well. 20 is a recommended value, and you may try larger + ones, but it is recommended to not exceed 50. + base_num_frames (`int`, *optional*, defaults to `97`): + 97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**) + ar_step (`int`, *optional*, defaults to `0`): + Controls asynchronous inference (0 for synchronous mode) You can set `ar_step=5` to enable asynchronous + inference. When asynchronous inference, `causal_block_size=5` is recommended while it is not supposed + to be set for synchronous generation. Asynchronous inference will take more steps to diffuse the whole + sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous + inference may improve the instruction following and visual consistent performance. + causal_block_size (`int`, *optional*, defaults to `None`): + The number of frames in each block/chunk. Recommended when using asynchronous inference (when ar_step > + 0) + fps (`int`, *optional*, defaults to `24`): + Frame rate of the generated video + + Examples: + + Returns: + [`~SkyReelsV2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds, + negative_prompt_embeds, + image_embeds, + callback_on_step_end_tensor_inputs, + overlap_history, + num_frames, + base_num_frames, + ) + + if addnoise_condition > 60: + logger.warning( + f"The value of 'addnoise_condition' is too large ({addnoise_condition}) and may cause inconsistencies in long video generation. A value of 20 is recommended." + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + if causal_block_size is None: + causal_block_size = self.transformer.config.num_frame_per_block + else: + self.transformer._set_ar_attention(causal_block_size) + + fps_embeds = [fps] * prompt_embeds.shape[0] + fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] + + # Determine if we're doing long video generation + is_long_video = overlap_history is not None and base_num_frames is not None and num_frames > base_num_frames + # Initialize accumulated_latents to store all latents in one tensor + accumulated_latents = None + if is_long_video: + # Long video generation setup + overlap_history_latent_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1 + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + base_latent_num_frames = ( + (base_num_frames - 1) // self.vae_scale_factor_temporal + 1 + if base_num_frames is not None + else num_latent_frames + ) + n_iter = ( + 1 + + (num_latent_frames - base_latent_num_frames - 1) + // (base_latent_num_frames - overlap_history_latent_frames) + + 1 + ) + else: + # Short video generation setup + n_iter = 1 + base_latent_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + + if last_image is not None: + last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( + device, dtype=torch.float32 + ) + + # Loop through iterations (multiple iterations only for long videos) + for iter_idx in range(n_iter): + if is_long_video: + logger.debug(f"Processing iteration {iter_idx + 1}/{n_iter} for long video generation...") + + num_channels_latents = self.vae.config.z_dim + latents, current_num_latent_frames, condition, prefix_video_latents_frames = self.prepare_latents( + image if iter_idx == 0 else None, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents if iter_idx == 0 else None, + last_image, + video_latents=accumulated_latents, # Pass latents directly instead of decoded video + base_latent_num_frames=base_latent_num_frames if is_long_video else None, + causal_block_size=causal_block_size, + overlap_history_latent_frames=overlap_history_latent_frames if is_long_video else None, + long_video_iter=iter_idx if is_long_video else None, + ) + + if iter_idx == 0: + latents[:, :, :prefix_video_latents_frames, :, :] = condition[: (condition.shape[0] + 1) // 2].to( + transformer_dtype + ) + else: + latents[:, :, :prefix_video_latents_frames, :, :] = condition.to(transformer_dtype) + + if iter_idx == 0 and last_image is not None: + end_video_latents = condition[condition.shape[0] // 2 :].to(transformer_dtype) + + if last_image is not None and iter_idx + 1 == n_iter: + latents = torch.cat([latents, end_video_latents], dim=2) + base_latent_num_frames += prefix_video_latents_frames + current_num_latent_frames += prefix_video_latents_frames + + # 4. Prepare sample schedulers and timestep matrix + sample_schedulers = [] + for _ in range(current_num_latent_frames): + sample_scheduler = deepcopy(self.scheduler) + sample_scheduler.set_timesteps(num_inference_steps, device=device) + sample_schedulers.append(sample_scheduler) + step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( + current_num_latent_frames, + timesteps, + base_latent_num_frames, + ar_step, + prefix_video_latents_frames, + causal_block_size, + ) + + if last_image is not None and iter_idx + 1 == n_iter: + step_matrix[:, -prefix_video_latents_frames:] = 0 + step_update_mask[:, -prefix_video_latents_frames:] = False + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(step_matrix) + + with self.progress_bar(total=len(step_matrix)) as progress_bar: + for i, t in enumerate(step_matrix): + if self.interrupt: + continue + + self._current_timestep = t + valid_interval_start, valid_interval_end = valid_interval[i] + latent_model_input = ( + latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + ) + timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone() + + if addnoise_condition > 0 and valid_interval_start < prefix_video_latents_frames: + noise_factor = 0.001 * addnoise_condition + latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :] = ( + latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :] + * (1.0 - noise_factor) + + torch.randn_like( + latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :] + ) + * noise_factor + ) + timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + enable_diffusion_forcing=True, + fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + enable_diffusion_forcing=True, + fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + update_mask_i = step_update_mask[i] + for idx in range(valid_interval_start, valid_interval_end): + if update_mask_i[idx].item(): + latents[:, :, idx, :, :] = sample_schedulers[idx].step( + noise_pred[:, :, idx - valid_interval_start, :, :], + t[idx], + latents[:, :, idx, :, :], + return_dict=False, + )[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(step_matrix) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # Handle latent accumulation for long videos or use the current latents for short videos + if is_long_video: + if accumulated_latents is None: + accumulated_latents = latents + else: + # Keep overlap frames for conditioning but don't include them in final output + accumulated_latents = torch.cat( + [accumulated_latents, latents[:, :, overlap_history_latent_frames:]], + dim=2, + ) + + if is_long_video: + latents = accumulated_latents + + self._current_timestep = None + + # Final decoding step - convert latents to pixels + if not output_type == "latent": + if last_image is not None: + latents = latents[:, :, :-prefix_video_latents_frames, :, :].to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return SkyReelsV2PipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py new file mode 100644 index 0000000000000000000000000000000000000000..924acb850d0908557a1919edbc305a1957ddac0a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -0,0 +1,1063 @@ +# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import math +import re +from copy import deepcopy +from typing import Any, Callable + +import torch +from PIL import Image +from transformers import AutoTokenizer, T5EncoderModel, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import SkyReelsV2LoraLoaderMixin +from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import SkyReelsV2PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """\ + Examples: + ```py + >>> import torch + >>> from diffusers import ( + ... SkyReelsV2DiffusionForcingVideoToVideoPipeline, + ... UniPCMultistepScheduler, + ... AutoencoderKLWan, + ... ) + >>> from diffusers.utils import export_to_video + + >>> # Load the pipeline + >>> # Available models: + >>> # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers + >>> # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers + >>> # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers + >>> vae = AutoencoderKLWan.from_pretrained( + ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers", + ... subfolder="vae", + ... torch_dtype=torch.float32, + ... ) + >>> pipe = SkyReelsV2DiffusionForcingVideoToVideoPipeline.from_pretrained( + ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers", + ... vae=vae, + ... torch_dtype=torch.bfloat16, + ... ) + >>> flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift) + >>> pipe = pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + + >>> output = pipe( + ... prompt=prompt, + ... num_inference_steps=50, + ... height=544, + ... width=960, + ... guidance_scale=6.0, # 6.0 for T2V, 5.0 for I2V + ... num_frames=97, + ... ar_step=0, # Controls asynchronous inference (0 for synchronous mode) + ... overlap_history=None, # Number of frames to overlap for smooth transitions in long videos + ... addnoise_condition=20, # Improves consistency in long video generation + ... ).frames[0] + >>> export_to_video(output, "video.mp4", fps=24, quality=8) + ``` +""" + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class SkyReelsV2DiffusionForcingVideoToVideoPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin): + """ + Pipeline for Video-to-Video (v2v) generation using SkyReels-V2 with diffusion forcing. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a specific device, etc.). + + Args: + tokenizer ([`AutoTokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`UMT5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + transformer ([`SkyReelsV2Transformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: T5EncoderModel | UMT5EncoderModel, + transformer: SkyReelsV2Transformer3DModel, + vae: AutoencoderKLWan, + scheduler: UniPCMultistepScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + video=None, + latents=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + overlap_history=None, + num_frames=None, + base_num_frames=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if video is not None and latents is not None: + raise ValueError("Only one of `video` or `latents` should be provided") + + if num_frames > base_num_frames and overlap_history is None: + raise ValueError( + "`overlap_history` is required when `num_frames` exceeds `base_num_frames` to ensure smooth transitions in long video generation. " + "Please specify a value for `overlap_history`. Recommended values are 17 or 37." + ) + + def prepare_latents( + self, + video: torch.Tensor, + batch_size: int = 1, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 97, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + video_latents: torch.Tensor | None = None, + base_latent_num_frames: int | None = None, + overlap_history: int | None = None, + causal_block_size: int | None = None, + overlap_history_latent_frames: int | None = None, + long_video_iter: int | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = ( + (num_frames - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.shape[2] + ) + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + if long_video_iter == 0: + prefix_video_latents = [ + retrieve_latents( + self.vae.encode( + vid.unsqueeze(0)[:, :, -overlap_history:] if vid.dim() == 4 else vid[:, :, -overlap_history:] + ), + sample_mode="argmax", + ) + for vid in video + ] + prefix_video_latents = torch.cat(prefix_video_latents, dim=0).to(dtype) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(device, self.vae.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, self.vae.dtype + ) + prefix_video_latents = (prefix_video_latents - latents_mean) * latents_std + else: + prefix_video_latents = video_latents[:, :, -overlap_history_latent_frames:] + + if prefix_video_latents.shape[2] % causal_block_size != 0: + truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size + logger.warning( + f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment. " + f"This truncation ensures compatibility with the causal block size, which is required for proper processing. " + f"However, it may slightly affect the continuity of the generated video at the truncation boundary." + ) + prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents] + prefix_video_latents_frames = prefix_video_latents.shape[2] + + finished_frame_num = ( + long_video_iter * (base_latent_num_frames - overlap_history_latent_frames) + overlap_history_latent_frames + ) + left_frame_num = num_latent_frames - finished_frame_num + num_latent_frames = min(left_frame_num + overlap_history_latent_frames, base_latent_num_frames) + + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + latent_height, + latent_width, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents, num_latent_frames, prefix_video_latents, prefix_video_latents_frames + + # Copied from diffusers.pipelines.skyreels_v2.pipeline_skyreels_v2_diffusion_forcing.SkyReelsV2DiffusionForcingPipeline.generate_timestep_matrix + def generate_timestep_matrix( + self, + num_latent_frames: int, + step_template: torch.Tensor, + base_num_latent_frames: int, + ar_step: int = 5, + num_pre_ready: int = 0, + causal_block_size: int = 1, + shrink_interval_with_mask: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]: + """ + This function implements the core diffusion forcing algorithm that creates a coordinated denoising schedule + across temporal frames. It supports both synchronous and asynchronous generation modes: + + **Synchronous Mode** (ar_step=0, causal_block_size=1): + - All frames are denoised simultaneously at each timestep + - Each frame follows the same denoising trajectory: [1000, 800, 600, ..., 0] + - Simpler but may have less temporal consistency for long videos + + **Asynchronous Mode** (ar_step>0, causal_block_size>1): + - Frames are grouped into causal blocks and processed block/chunk-wise + - Each block is denoised in a staggered pattern creating a "denoising wave" + - Earlier blocks are more denoised, later blocks lag behind by ar_step timesteps + - Creates stronger temporal dependencies and better consistency + + Args: + num_latent_frames (int): Total number of latent frames to generate + step_template (torch.Tensor): Base timestep schedule (e.g., [1000, 800, 600, ..., 0]) + base_num_latent_frames (int): Maximum frames the model can process in one forward pass + ar_step (int, optional): Autoregressive step size for temporal lag. + 0 = synchronous, >0 = asynchronous. Defaults to 5. + num_pre_ready (int, optional): + Number of frames already denoised (e.g., from prefix in a video2video task). + Defaults to 0. + causal_block_size (int, optional): Number of frames processed as a causal block. + Defaults to 1. + shrink_interval_with_mask (bool, optional): Whether to optimize processing intervals. + Defaults to False. + + Returns: + tuple containing: + - step_matrix (torch.Tensor): Matrix of timesteps for each frame at each iteration Shape: + [num_iterations, num_latent_frames] + - step_index (torch.Tensor): Index matrix for timestep lookup Shape: [num_iterations, + num_latent_frames] + - step_update_mask (torch.Tensor): Boolean mask indicating which frames to update Shape: + [num_iterations, num_latent_frames] + - valid_interval (list[tuple]): list of (start, end) intervals for each iteration + + Raises: + ValueError: If ar_step is too small for the given configuration + """ + # Initialize lists to store the scheduling matrices and metadata + step_matrix, step_index = [], [] # Will store timestep values and indices for each iteration + update_mask, valid_interval = [], [] # Will store update masks and processing intervals + + # Calculate total number of denoising iterations (add 1 for initial noise state) + num_iterations = len(step_template) + 1 + + # Convert frame counts to block counts for causal processing + # Each block contains causal_block_size frames that are processed together + # E.g.: 25 frames ÷ 5 = 5 blocks total + num_blocks = num_latent_frames // causal_block_size + base_num_blocks = base_num_latent_frames // causal_block_size + + # Validate ar_step is sufficient for the given configuration + # In asynchronous mode, we need enough timesteps to create the staggered pattern + if base_num_blocks < num_blocks: + min_ar_step = len(step_template) / base_num_blocks + if ar_step < min_ar_step: + raise ValueError(f"`ar_step` should be at least {math.ceil(min_ar_step)} in your setting") + + # Extend step_template with boundary values for easier indexing + # 999: dummy value for counter starting from 1 + # 0: final timestep (completely denoised) + step_template = torch.cat( + [ + torch.tensor([999], dtype=torch.int64, device=step_template.device), + step_template.long(), + torch.tensor([0], dtype=torch.int64, device=step_template.device), + ] + ) + + # Initialize the previous row state (tracks denoising progress for each block) + # 0 means not started, num_iterations means fully denoised + pre_row = torch.zeros(num_blocks, dtype=torch.long) + + # Mark pre-ready frames (e.g., from prefix video for a video2video task) as already at final denoising state + if num_pre_ready > 0: + pre_row[: num_pre_ready // causal_block_size] = num_iterations + + # Main loop: Generate denoising schedule until all frames are fully denoised + while not torch.all(pre_row >= (num_iterations - 1)): + # Create new row representing the next denoising step + new_row = torch.zeros(num_blocks, dtype=torch.long) + + # Apply diffusion forcing logic for each block + for i in range(num_blocks): + if i == 0 or pre_row[i - 1] >= ( + num_iterations - 1 + ): # the first frame or the last frame is completely denoised + new_row[i] = pre_row[i] + 1 + else: + # Asynchronous mode: lag behind previous block by ar_step timesteps + # This creates the "diffusion forcing" staggered pattern + new_row[i] = new_row[i - 1] - ar_step + + # Clamp values to valid range [0, num_iterations] + new_row = new_row.clamp(0, num_iterations) + + # Create update mask: True for blocks that need denoising update at this iteration + # Exclude blocks that haven't started (new_row != pre_row) or are finished (new_row != num_iterations) + # Final state example: [False, ..., False, True, True, True, True, True] + # where first 20 frames are done (False) and last 5 frames still need updates (True) + update_mask.append((new_row != pre_row) & (new_row != num_iterations)) + + # Store the iteration state + step_index.append(new_row) # Index into step_template + step_matrix.append(step_template[new_row]) # Actual timestep values + pre_row = new_row # Update for next iteration + + # For videos longer than model capacity, we process in sliding windows + terminal_flag = base_num_blocks + + # Optional optimization: shrink interval based on first update mask + if shrink_interval_with_mask: + idx_sequence = torch.arange(num_blocks, dtype=torch.int64) + update_mask = update_mask[0] + update_mask_idx = idx_sequence[update_mask] + last_update_idx = update_mask_idx[-1].item() + terminal_flag = last_update_idx + 1 + + # Each interval defines which frames to process in the current forward pass + for curr_mask in update_mask: + # Extend terminal flag if current mask has updates beyond current terminal + if terminal_flag < num_blocks and curr_mask[terminal_flag]: + terminal_flag += 1 + # Create interval: [start, end) where start ensures we don't exceed model capacity + valid_interval.append((max(terminal_flag - base_num_blocks, 0), terminal_flag)) + + # Convert lists to tensors for efficient processing + step_update_mask = torch.stack(update_mask, dim=0) + step_index = torch.stack(step_index, dim=0) + step_matrix = torch.stack(step_matrix, dim=0) + + # Each block's schedule is replicated to all frames within that block + if causal_block_size > 1: + # Expand each block to causal_block_size frames + step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous() + step_index = step_index.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous() + step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous() + # Scale intervals from block-level to frame-level + valid_interval = [(s * causal_block_size, e * causal_block_size) for s, e in valid_interval] + + return step_matrix, step_index, step_update_mask, valid_interval + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + video: list[Image.Image], + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int = 544, + width: int = 960, + num_frames: int = 120, + num_inference_steps: int = 50, + guidance_scale: float = 6.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + overlap_history: int | None = None, + addnoise_condition: float = 0, + base_num_frames: int = 97, + ar_step: int = 0, + causal_block_size: int | None = None, + fps: int = 24, + ): + r""" + The call function to the pipeline for generation. + + Args: + video (`list[Image.Image]`): + The video to guide the video generation. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, defaults to `544`): + The height of the generated video. + width (`int`, defaults to `960`): + The width of the generated video. + num_frames (`int`, defaults to `120`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `6.0`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. (**6.0 for T2V**, **5.0 for I2V**) + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`SkyReelsV2PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to `512`): + The maximum sequence length of the prompt. + overlap_history (`int`, *optional*, defaults to `None`): + Number of frames to overlap for smooth transitions in long videos. If `None`, the pipeline assumes + short video generation mode, and no overlap is applied. 17 and 37 are recommended to set. + addnoise_condition (`float`, *optional*, defaults to `0`): + This is used to help smooth the long video generation by adding some noise to the clean condition. Too + large noise can cause the inconsistency as well. 20 is a recommended value, and you may try larger + ones, but it is recommended to not exceed 50. + base_num_frames (`int`, *optional*, defaults to `97`): + 97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**) + ar_step (`int`, *optional*, defaults to `0`): + Controls asynchronous inference (0 for synchronous mode) You can set `ar_step=5` to enable asynchronous + inference. When asynchronous inference, `causal_block_size=5` is recommended while it is not supposed + to be set for synchronous generation. Asynchronous inference will take more steps to diffuse the whole + sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous + inference may improve the instruction following and visual consistent performance. + causal_block_size (`int`, *optional*, defaults to `None`): + The number of frames in each block/chunk. Recommended when using asynchronous inference (when ar_step > + 0) + fps (`int`, *optional*, defaults to `24`): + Frame rate of the generated video + + Examples: + + Returns: + [`~SkyReelsV2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + video, + latents, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + overlap_history, + num_frames, + base_num_frames, + ) + + if addnoise_condition > 60: + logger.warning( + f"The value of 'addnoise_condition' is too large ({addnoise_condition}) and may cause inconsistencies in long video generation. A value of 20 is recommended." + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + if latents is None: + video_original = self.video_processor.preprocess_video(video, height=height, width=width).to( + device, dtype=torch.float32 + ) + + if causal_block_size is None: + causal_block_size = self.transformer.config.num_frame_per_block + else: + self.transformer._set_ar_attention(causal_block_size) + + fps_embeds = [fps] * prompt_embeds.shape[0] + fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] + + # Long video generation + accumulated_latents = None + overlap_history_latent_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1 + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + base_latent_num_frames = ( + (base_num_frames - 1) // self.vae_scale_factor_temporal + 1 + if base_num_frames is not None + else num_latent_frames + ) + n_iter = ( + 1 + + (num_latent_frames - base_latent_num_frames - 1) + // (base_latent_num_frames - overlap_history_latent_frames) + + 1 + ) + for long_video_iter in range(n_iter): + logger.debug(f"Processing iteration {long_video_iter + 1}/{n_iter} for long video generation...") + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents, current_num_latent_frames, prefix_video_latents, prefix_video_latents_frames = ( + self.prepare_latents( + video_original, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents if long_video_iter == 0 else None, + video_latents=accumulated_latents, # Pass latents directly instead of decoded video + overlap_history=overlap_history, + base_latent_num_frames=base_latent_num_frames, + causal_block_size=causal_block_size, + overlap_history_latent_frames=overlap_history_latent_frames, + long_video_iter=long_video_iter, + ) + ) + + if prefix_video_latents_frames > 0: + latents[:, :, :prefix_video_latents_frames, :, :] = prefix_video_latents.to(transformer_dtype) + + # 4. Prepare sample schedulers and timestep matrix + sample_schedulers = [] + for _ in range(current_num_latent_frames): + sample_scheduler = deepcopy(self.scheduler) + sample_scheduler.set_timesteps(num_inference_steps, device=device) + sample_schedulers.append(sample_scheduler) + step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( + current_num_latent_frames, + timesteps, + current_num_latent_frames, + ar_step, + prefix_video_latents_frames, + causal_block_size, + ) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(step_matrix) + + with self.progress_bar(total=len(step_matrix)) as progress_bar: + for i, t in enumerate(step_matrix): + if self.interrupt: + continue + + self._current_timestep = t + valid_interval_start, valid_interval_end = valid_interval[i] + latent_model_input = ( + latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + ) + timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone() + + if addnoise_condition > 0 and valid_interval_start < prefix_video_latents_frames: + noise_factor = 0.001 * addnoise_condition + latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :] = ( + latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :] + * (1.0 - noise_factor) + + torch.randn_like( + latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :] + ) + * noise_factor + ) + timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + enable_diffusion_forcing=True, + fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + enable_diffusion_forcing=True, + fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + update_mask_i = step_update_mask[i] + for idx in range(valid_interval_start, valid_interval_end): + if update_mask_i[idx].item(): + latents[:, :, idx, :, :] = sample_schedulers[idx].step( + noise_pred[:, :, idx - valid_interval_start, :, :], + t[idx], + latents[:, :, idx, :, :], + return_dict=False, + )[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(step_matrix) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if accumulated_latents is None: + accumulated_latents = latents + else: + # Keep overlap frames for conditioning but don't include them in final output + accumulated_latents = torch.cat( + [accumulated_latents, latents[:, :, overlap_history_latent_frames:]], dim=2 + ) + + latents = accumulated_latents + + self._current_timestep = None + + # Final decoding step - convert latents to pixels + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video_generated = self.vae.decode(latents, return_dict=False)[0] + video = torch.cat([video_original, video_generated], dim=2) + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return SkyReelsV2PipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py new file mode 100644 index 0000000000000000000000000000000000000000..7c24b898e0bb0de16bdb866a67007ce55a5f8d30 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py @@ -0,0 +1,745 @@ +# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +from typing import Any, Callable + +import PIL +import regex as re +import torch +from transformers import AutoTokenizer, CLIPProcessor, CLIPVisionModelWithProjection, T5EncoderModel, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import SkyReelsV2LoraLoaderMixin +from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import SkyReelsV2PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """\ + Examples: + ```py + >>> import torch + >>> from diffusers import ( + ... SkyReelsV2ImageToVideoPipeline, + ... UniPCMultistepScheduler, + ... AutoencoderKLWan, + ... ) + >>> from diffusers.utils import export_to_video + >>> from PIL import Image + + >>> # Load the pipeline + >>> # Available models: + >>> # - Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers + >>> # - Skywork/SkyReels-V2-I2V-14B-540P-Diffusers + >>> # - Skywork/SkyReels-V2-I2V-14B-720P-Diffusers + >>> vae = AutoencoderKLWan.from_pretrained( + ... "Skywork/SkyReels-V2-I2V-14B-720P-Diffusers", + ... subfolder="vae", + ... torch_dtype=torch.float32, + ... ) + >>> pipe = SkyReelsV2ImageToVideoPipeline.from_pretrained( + ... "Skywork/SkyReels-V2-I2V-14B-720P-Diffusers", + ... vae=vae, + ... torch_dtype=torch.bfloat16, + ... ) + >>> flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift) + >>> pipe = pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + >>> image = Image.open("path/to/image.png") + + >>> output = pipe( + ... image=image, + ... prompt=prompt, + ... num_inference_steps=50, + ... height=544, + ... width=960, + ... guidance_scale=5.0, # 6.0 for T2V, 5.0 for I2V + ... num_frames=97, + ... ).frames[0] + >>> export_to_video(output, "video.mp4", fps=24, quality=8) + ``` +""" + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class SkyReelsV2ImageToVideoPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin): + r""" + Pipeline for Image-to-Video (i2v) generation using SkyReels-V2. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + image_encoder ([`CLIPVisionModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection), + specifically the + [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large) + variant. + transformer ([`SkyReelsV2Transformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: T5EncoderModel | UMT5EncoderModel, + image_encoder: CLIPVisionModelWithProjection, + image_processor: CLIPProcessor, + transformer: SkyReelsV2Transformer3DModel, + vae: AutoencoderKLWan, + scheduler: UniPCMultistepScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + image_encoder=image_encoder, + transformer=transformer, + scheduler=scheduler, + image_processor=image_processor, + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.image_processor = image_processor + + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_image + def encode_image( + self, + image: PipelineImageInput, + device: torch.device | None = None, + ): + device = device or self._execution_device + image = self.image_processor(images=image, return_tensors="pt").to(device) + image_embeds = self.image_encoder(**image, output_hidden_states=True) + return image_embeds.hidden_states[-2] + + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if image is not None and image_embeds is not None: + raise ValueError( + f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to" + " only forward one of the two." + ) + if image is None and image_embeds is None: + raise ValueError( + "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined." + ) + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + def prepare_latents( + self, + image: PipelineImageInput, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + last_image: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + image = image.unsqueeze(2) + if last_image is None: + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 + ) + else: + last_image = last_image.unsqueeze(2) + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image], + dim=2, + ) + video_condition = video_condition.to(device=device, dtype=self.vae.dtype) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + + if isinstance(generator, list): + latent_condition = [ + retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator + ] + latent_condition = torch.cat(latent_condition) + else: + latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") + latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + + latent_condition = latent_condition.to(dtype) + latent_condition = (latent_condition - latents_mean) * latents_std + + mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) + + if last_image is None: + mask_lat_size[:, :, list(range(1, num_frames))] = 0 + else: + mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0 + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(latent_condition.device) + + return latents, torch.concat([mask_lat_size, latent_condition], dim=1) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int = 544, + width: int = 960, + num_frames: int = 97, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + image_embeds: torch.Tensor | None = None, + last_image: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, defaults to `544`): + The height of the generated video. + width (`int`, defaults to `960`): + The width of the generated video. + num_frames (`int`, defaults to `97`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `negative_prompt` input argument. + image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided, + image embeddings are generated from the `image` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to `512`): + The maximum sequence length of the prompt. + + Examples: + + Returns: + [`~SkyReelsV2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds, + negative_prompt_embeds, + image_embeds, + callback_on_step_end_tensor_inputs, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + # Encode image embedding + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + if image_embeds is None: + if last_image is None: + image_embeds = self.encode_image(image, device) + else: + image_embeds = self.encode_image([image, last_image], device) + image_embeds = image_embeds.repeat(batch_size, 1, 1) + image_embeds = image_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.vae.config.z_dim + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + if last_image is not None: + last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( + device, dtype=torch.float32 + ) + latents, condition = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + last_image, + ) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return SkyReelsV2PipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_audio/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_audio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dfdd419ae9914e64c7fdcf7c152ac308b75d75d2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_audio/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modeling_stable_audio"] = ["StableAudioProjectionModel"] + _import_structure["pipeline_stable_audio"] = ["StableAudioPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + + else: + from .modeling_stable_audio import StableAudioProjectionModel + from .pipeline_stable_audio import StableAudioPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_audio/modeling_stable_audio.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_audio/modeling_stable_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..d40269411bc01503be84c6744d5605efb32ab2eb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_audio/modeling_stable_audio.py @@ -0,0 +1,156 @@ +# Copyright 2025 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from math import pi + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin +from ...utils import BaseOutput, logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableAudioPositionalEmbedding(nn.Module): + """Used for continuous time""" + + def __init__(self, dim: int): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, times: torch.Tensor) -> torch.Tensor: + times = times[..., None] + freqs = times * self.weights[None] * 2 * pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((times, fouriered), dim=-1) + return fouriered + + +@dataclass +class StableAudioProjectionModelOutput(BaseOutput): + """ + Args: + Class for StableAudio projection layer's outputs. + text_hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states obtained by linearly projecting the hidden-states for the text encoder. + seconds_start_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*): + Sequence of hidden-states obtained by linearly projecting the audio start hidden states. + seconds_end_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*): + Sequence of hidden-states obtained by linearly projecting the audio end hidden states. + """ + + text_hidden_states: torch.Tensor | None = None + seconds_start_hidden_states: torch.Tensor | None = None + seconds_end_hidden_states: torch.Tensor | None = None + + +class StableAudioNumberConditioner(nn.Module): + """ + A simple linear projection model to map numbers to a latent space. + + Args: + number_embedding_dim (`int`): + Dimensionality of the number embeddings. + min_value (`int`): + The minimum value of the seconds number conditioning modules. + max_value (`int`): + The maximum value of the seconds number conditioning modules + internal_dim (`int`): + Dimensionality of the intermediate number hidden states. + """ + + def __init__( + self, + number_embedding_dim, + min_value, + max_value, + internal_dim: int | None = 256, + ): + super().__init__() + self.time_positional_embedding = nn.Sequential( + StableAudioPositionalEmbedding(internal_dim), + nn.Linear(in_features=internal_dim + 1, out_features=number_embedding_dim), + ) + + self.number_embedding_dim = number_embedding_dim + self.min_value = min_value + self.max_value = max_value + + def forward( + self, + floats: torch.Tensor, + ): + floats = floats.clamp(self.min_value, self.max_value) + + normalized_floats = (floats - self.min_value) / (self.max_value - self.min_value) + + # Cast floats to same type as embedder + embedder_dtype = next(self.time_positional_embedding.parameters()).dtype + normalized_floats = normalized_floats.to(embedder_dtype) + + embedding = self.time_positional_embedding(normalized_floats) + float_embeds = embedding.view(-1, 1, self.number_embedding_dim) + + return float_embeds + + +class StableAudioProjectionModel(ModelMixin, ConfigMixin): + """ + A simple linear projection model to map the conditioning values to a shared latent space. + + Args: + text_encoder_dim (`int`): + Dimensionality of the text embeddings from the text encoder (T5). + conditioning_dim (`int`): + Dimensionality of the output conditioning tensors. + min_value (`int`): + The minimum value of the seconds number conditioning modules. + max_value (`int`): + The maximum value of the seconds number conditioning modules + """ + + @register_to_config + def __init__(self, text_encoder_dim, conditioning_dim, min_value, max_value): + super().__init__() + self.text_projection = ( + nn.Identity() if conditioning_dim == text_encoder_dim else nn.Linear(text_encoder_dim, conditioning_dim) + ) + self.start_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value) + self.end_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value) + + def forward( + self, + text_hidden_states: torch.Tensor | None = None, + start_seconds: torch.Tensor | None = None, + end_seconds: torch.Tensor | None = None, + ): + text_hidden_states = ( + text_hidden_states if text_hidden_states is None else self.text_projection(text_hidden_states) + ) + seconds_start_hidden_states = ( + start_seconds if start_seconds is None else self.start_number_conditioner(start_seconds) + ) + seconds_end_hidden_states = end_seconds if end_seconds is None else self.end_number_conditioner(end_seconds) + + return StableAudioProjectionModelOutput( + text_hidden_states=text_hidden_states, + seconds_start_hidden_states=seconds_start_hidden_states, + seconds_end_hidden_states=seconds_end_hidden_states, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_audio/pipeline_stable_audio.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_audio/pipeline_stable_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..351c8b65de0ebea9d322edd2b71ab28295f91f7a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_audio/pipeline_stable_audio.py @@ -0,0 +1,764 @@ +# Copyright 2025 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import torch +from transformers import ( + T5EncoderModel, + T5Tokenizer, + T5TokenizerFast, +) + +from ...models import AutoencoderOobleck, StableAudioDiTModel +from ...models.embeddings import get_1d_rotary_pos_embed +from ...schedulers import EDMDPMSolverMultistepScheduler +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline +from .modeling_stable_audio import StableAudioProjectionModel + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import scipy + >>> import torch + >>> import soundfile as sf + >>> from diffusers import StableAudioPipeline + + >>> repo_id = "stabilityai/stable-audio-open-1.0" + >>> pipe = StableAudioPipeline.from_pretrained(repo_id, torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> # define the prompts + >>> prompt = "The sound of a hammer hitting a wooden surface." + >>> negative_prompt = "Low quality." + + >>> # set the seed for generator + >>> generator = torch.Generator("cuda").manual_seed(0) + + >>> # run the generation + >>> audio = pipe( + ... prompt, + ... negative_prompt=negative_prompt, + ... num_inference_steps=200, + ... audio_end_in_s=10.0, + ... num_waveforms_per_prompt=3, + ... generator=generator, + ... ).audios + + >>> output = audio[0].T.float().cpu().numpy() + >>> sf.write("hammer.wav", output, pipe.vae.sampling_rate) + ``` +""" + + +class StableAudioPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-audio generation using StableAudio. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderOobleck`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.T5EncoderModel`]): + Frozen text-encoder. StableAudio uses the encoder of + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [google-t5/t5-base](https://huggingface.co/google-t5/t5-base) variant. + projection_model ([`StableAudioProjectionModel`]): + A trained model used to linearly project the hidden-states from the text encoder model and the start and + end seconds. The projected hidden-states from the encoder and the conditional seconds are concatenated to + give the input to the transformer model. + tokenizer ([`~transformers.T5Tokenizer`]): + Tokenizer to tokenize text for the frozen text-encoder. + transformer ([`StableAudioDiTModel`]): + A `StableAudioDiTModel` to denoise the encoded audio latents. + scheduler ([`EDMDPMSolverMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded audio latents. + """ + + model_cpu_offload_seq = "text_encoder->projection_model->transformer->vae" + + def __init__( + self, + vae: AutoencoderOobleck, + text_encoder: T5EncoderModel, + projection_model: StableAudioProjectionModel, + tokenizer: T5Tokenizer | T5TokenizerFast, + transformer: StableAudioDiTModel, + scheduler: EDMDPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + projection_model=projection_model, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.rotary_embed_dim = self.transformer.config.attention_head_dim // 2 + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def encode_prompt( + self, + prompt, + device, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + attention_mask: torch.LongTensor | None = None, + negative_attention_mask: torch.LongTensor | None = None, + ): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # 1. Tokenize text + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + f"The following part of your input was truncated because {self.text_encoder.config.model_type} can " + f"only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_input_ids = text_input_ids.to(device) + attention_mask = attention_mask.to(device) + + # 2. Text encoder forward + self.text_encoder.eval() + prompt_embeds = self.text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if do_classifier_free_guidance and negative_prompt is not None: + uncond_tokens: list[str] + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # 1. Tokenize text + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_input_ids = uncond_input.input_ids.to(device) + negative_attention_mask = uncond_input.attention_mask.to(device) + + # 2. Text encoder forward + self.text_encoder.eval() + negative_prompt_embeds = self.text_encoder( + uncond_input_ids, + attention_mask=negative_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if negative_attention_mask is not None: + # set the masked tokens to the null embed + negative_prompt_embeds = torch.where( + negative_attention_mask.to(torch.bool).unsqueeze(2), negative_prompt_embeds, 0.0 + ) + + # 3. Project prompt_embeds and negative_prompt_embeds + if do_classifier_free_guidance and negative_prompt_embeds is not None: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the negative and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if attention_mask is not None and negative_attention_mask is None: + negative_attention_mask = torch.ones_like(attention_mask) + elif attention_mask is None and negative_attention_mask is not None: + attention_mask = torch.ones_like(negative_attention_mask) + + if attention_mask is not None: + attention_mask = torch.cat([negative_attention_mask, attention_mask]) + + prompt_embeds = self.projection_model( + text_hidden_states=prompt_embeds, + ).text_hidden_states + if attention_mask is not None: + prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype) + prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype) + + return prompt_embeds + + def encode_duration( + self, + audio_start_in_s, + audio_end_in_s, + device, + do_classifier_free_guidance, + batch_size, + ): + audio_start_in_s = audio_start_in_s if isinstance(audio_start_in_s, list) else [audio_start_in_s] + audio_end_in_s = audio_end_in_s if isinstance(audio_end_in_s, list) else [audio_end_in_s] + + if len(audio_start_in_s) == 1: + audio_start_in_s = audio_start_in_s * batch_size + if len(audio_end_in_s) == 1: + audio_end_in_s = audio_end_in_s * batch_size + + # Cast the inputs to floats + audio_start_in_s = [float(x) for x in audio_start_in_s] + audio_start_in_s = torch.tensor(audio_start_in_s).to(device) + + audio_end_in_s = [float(x) for x in audio_end_in_s] + audio_end_in_s = torch.tensor(audio_end_in_s).to(device) + + projection_output = self.projection_model( + start_seconds=audio_start_in_s, + end_seconds=audio_end_in_s, + ) + seconds_start_hidden_states = projection_output.seconds_start_hidden_states + seconds_end_hidden_states = projection_output.seconds_end_hidden_states + + # For classifier free guidance, we need to do two forward passes. + # Here we repeat the audio hidden states to avoid doing two forward passes + if do_classifier_free_guidance: + seconds_start_hidden_states = torch.cat([seconds_start_hidden_states, seconds_start_hidden_states], dim=0) + seconds_end_hidden_states = torch.cat([seconds_end_hidden_states, seconds_end_hidden_states], dim=0) + + return seconds_start_hidden_states, seconds_end_hidden_states + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + audio_start_in_s, + audio_end_in_s, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + attention_mask=None, + negative_attention_mask=None, + initial_audio_waveforms=None, + initial_audio_sampling_rate=None, + ): + if audio_end_in_s < audio_start_in_s: + raise ValueError( + f"`audio_end_in_s={audio_end_in_s}' must be higher than 'audio_start_in_s={audio_start_in_s}` but " + ) + + if ( + audio_start_in_s < self.projection_model.config.min_value + or audio_start_in_s > self.projection_model.config.max_value + ): + raise ValueError( + f"`audio_start_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but " + f"is {audio_start_in_s}." + ) + + if ( + audio_end_in_s < self.projection_model.config.min_value + or audio_end_in_s > self.projection_model.config.max_value + ): + raise ValueError( + f"`audio_end_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but " + f"is {audio_end_in_s}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and (prompt_embeds is None): + raise ValueError( + "Provide either `prompt`, or `prompt_embeds`. Cannot leave" + "`prompt` undefined without specifying `prompt_embeds`." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]: + raise ValueError( + "`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:" + f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}" + ) + + if initial_audio_sampling_rate is None and initial_audio_waveforms is not None: + raise ValueError( + "`initial_audio_waveforms' is provided but the sampling rate is not. Make sure to pass `initial_audio_sampling_rate`." + ) + + if initial_audio_sampling_rate is not None and initial_audio_sampling_rate != self.vae.sampling_rate: + raise ValueError( + f"`initial_audio_sampling_rate` must be {self.vae.hop_length}' but is `{initial_audio_sampling_rate}`." + "Make sure to resample the `initial_audio_waveforms` and to correct the sampling rate. " + ) + + def prepare_latents( + self, + batch_size, + num_channels_vae, + sample_size, + dtype, + device, + generator, + latents=None, + initial_audio_waveforms=None, + num_waveforms_per_prompt=None, + audio_channels=None, + ): + shape = (batch_size, num_channels_vae, sample_size) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # encode the initial audio for use by the model + if initial_audio_waveforms is not None: + # check dimension + if initial_audio_waveforms.ndim == 2: + initial_audio_waveforms = initial_audio_waveforms.unsqueeze(1) + elif initial_audio_waveforms.ndim != 3: + raise ValueError( + f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions" + ) + + audio_vae_length = int(self.transformer.config.sample_size) * self.vae.hop_length + audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length) + + # check num_channels + if initial_audio_waveforms.shape[1] == 1 and audio_channels == 2: + initial_audio_waveforms = initial_audio_waveforms.repeat(1, 2, 1) + elif initial_audio_waveforms.shape[1] == 2 and audio_channels == 1: + initial_audio_waveforms = initial_audio_waveforms.mean(1, keepdim=True) + + if initial_audio_waveforms.shape[:2] != audio_shape[:2]: + raise ValueError( + f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but is of shape `{initial_audio_waveforms.shape}`" + ) + + # crop or pad + audio_length = initial_audio_waveforms.shape[-1] + if audio_length < audio_vae_length: + logger.warning( + f"The provided input waveform is shorter ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be padded." + ) + elif audio_length > audio_vae_length: + logger.warning( + f"The provided input waveform is longer ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be cropped." + ) + + audio = initial_audio_waveforms.new_zeros(audio_shape) + audio[:, :, : min(audio_length, audio_vae_length)] = initial_audio_waveforms[:, :, :audio_vae_length] + + encoded_audio = self.vae.encode(audio).latent_dist.sample(generator) + encoded_audio = encoded_audio.repeat((num_waveforms_per_prompt, 1, 1)) + latents = encoded_audio + latents + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + audio_end_in_s: float | None = None, + audio_start_in_s: float | None = 0.0, + num_inference_steps: int = 100, + guidance_scale: float = 7.0, + negative_prompt: str | list[str] | None = None, + num_waveforms_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + initial_audio_waveforms: torch.Tensor | None = None, + initial_audio_sampling_rate: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + attention_mask: torch.LongTensor | None = None, + negative_attention_mask: torch.LongTensor | None = None, + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int | None = 1, + output_type: str | None = "pt", + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`. + audio_end_in_s (`float`, *optional*, defaults to 47.55): + Audio end index in seconds. + audio_start_in_s (`float`, *optional*, defaults to 0): + Audio start index in seconds. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality audio at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.0): + A higher guidance scale value encourages the model to generate audio that is closely linked to the text + `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in audio generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_waveforms_per_prompt (`int`, *optional*, defaults to 1): + The number of waveforms to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for audio + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + initial_audio_waveforms (`torch.Tensor`, *optional*): + Optional initial audio waveforms to use as the initial audio waveform for generation. Must be of shape + `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)`, where `batch_size` + corresponds to the number of prompts passed to the model. + initial_audio_sampling_rate (`int`, *optional*): + Sampling rate of the `initial_audio_waveforms`, if they are provided. Must be the same as the model. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-computed text embeddings from the text encoder model. Can be used to easily tweak text inputs, + *e.g.* prompt weighting. If not provided, text embeddings will be computed from `prompt` input + argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-computed negative text embeddings from the text encoder model. Can be used to easily tweak text + inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from + `negative_prompt` input argument. + attention_mask (`torch.LongTensor`, *optional*): + Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will + be computed from `prompt` input argument. + negative_attention_mask (`torch.LongTensor`, *optional*): + Pre-computed attention mask to be applied to the `negative_text_audio_duration_embeds`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + output_type (`str`, *optional*, defaults to `"pt"`): + The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or + `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion + model (LDM) output. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated audio. + """ + # 0. Convert audio input length from seconds to latent length + downsample_ratio = self.vae.hop_length + + max_audio_length_in_s = self.transformer.config.sample_size * downsample_ratio / self.vae.config.sampling_rate + if audio_end_in_s is None: + audio_end_in_s = max_audio_length_in_s + + if audio_end_in_s - audio_start_in_s > max_audio_length_in_s: + raise ValueError( + f"The total audio length requested ({audio_end_in_s - audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'." + ) + + waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate) + waveform_end = int(audio_end_in_s * self.vae.config.sampling_rate) + waveform_length = int(self.transformer.config.sample_size) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + audio_start_in_s, + audio_end_in_s, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + attention_mask, + negative_attention_mask, + initial_audio_waveforms, + initial_audio_sampling_rate, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self.encode_prompt( + prompt, + device, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + attention_mask, + negative_attention_mask, + ) + + # Encode duration + seconds_start_hidden_states, seconds_end_hidden_states = self.encode_duration( + audio_start_in_s, + audio_end_in_s, + device, + do_classifier_free_guidance and (negative_prompt is not None or negative_prompt_embeds is not None), + batch_size, + ) + + # Create text_audio_duration_embeds and audio_duration_embeds + text_audio_duration_embeds = torch.cat( + [prompt_embeds, seconds_start_hidden_states, seconds_end_hidden_states], dim=1 + ) + + audio_duration_embeds = torch.cat([seconds_start_hidden_states, seconds_end_hidden_states], dim=2) + + # In case of classifier free guidance without negative prompt, we need to create unconditional embeddings and + # to concatenate it to the embeddings + if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None: + negative_text_audio_duration_embeds = torch.zeros_like( + text_audio_duration_embeds, device=text_audio_duration_embeds.device + ) + text_audio_duration_embeds = torch.cat( + [negative_text_audio_duration_embeds, text_audio_duration_embeds], dim=0 + ) + audio_duration_embeds = torch.cat([audio_duration_embeds, audio_duration_embeds], dim=0) + + bs_embed, seq_len, hidden_size = text_audio_duration_embeds.shape + # duplicate audio_duration_embeds and text_audio_duration_embeds for each generation per prompt, using mps friendly method + text_audio_duration_embeds = text_audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1) + text_audio_duration_embeds = text_audio_duration_embeds.view( + bs_embed * num_waveforms_per_prompt, seq_len, hidden_size + ) + + audio_duration_embeds = audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1) + audio_duration_embeds = audio_duration_embeds.view( + bs_embed * num_waveforms_per_prompt, -1, audio_duration_embeds.shape[-1] + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_vae = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_waveforms_per_prompt, + num_channels_vae, + waveform_length, + text_audio_duration_embeds.dtype, + device, + generator, + latents, + initial_audio_waveforms, + num_waveforms_per_prompt, + audio_channels=self.vae.config.audio_channels, + ) + + # 6. Prepare extra step kwargs + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare rotary positional embedding + rotary_embedding = get_1d_rotary_pos_embed( + self.rotary_embed_dim, + latents.shape[2] + audio_duration_embeds.shape[1], + use_real=True, + repeat_interleave_real=False, + ) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t.unsqueeze(0), + encoder_hidden_states=text_audio_duration_embeds, + global_hidden_states=audio_duration_embeds, + rotary_embedding=rotary_embedding, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # 9. Post-processing + if not output_type == "latent": + audio = self.vae.decode(latents).sample + else: + return AudioPipelineOutput(audios=latents) + + audio = audio[:, :, waveform_start:waveform_end] + + if output_type == "np": + audio = audio.cpu().float().numpy() + + self.maybe_free_model_hooks() + + if not return_dict: + return (audio,) + + return AudioPipelineOutput(audios=audio) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_cascade/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_cascade/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5270cb94af01fd94ed6c8e76c243a86cad8ec348 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_cascade/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_cascade"] = ["StableCascadeDecoderPipeline"] + _import_structure["pipeline_stable_cascade_combined"] = ["StableCascadeCombinedPipeline"] + _import_structure["pipeline_stable_cascade_prior"] = ["StableCascadePriorPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_stable_cascade import StableCascadeDecoderPipeline + from .pipeline_stable_cascade_combined import StableCascadeCombinedPipeline + from .pipeline_stable_cascade_prior import StableCascadePriorPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_cascade/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_cascade/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e6d1b843aa0b11a04cec38a19d731caeb4f1c5d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_cascade/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_cascade/__pycache__/pipeline_stable_cascade.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_cascade/__pycache__/pipeline_stable_cascade.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67e210c038987fbbece7d21b2b8f6e49a64eed2b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_cascade/__pycache__/pipeline_stable_cascade.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_cascade/__pycache__/pipeline_stable_cascade_combined.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_cascade/__pycache__/pipeline_stable_cascade_combined.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..791d7cccafd825b05e6373e523258dcbb2e9214f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_cascade/__pycache__/pipeline_stable_cascade_combined.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_cascade/__pycache__/pipeline_stable_cascade_prior.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_cascade/__pycache__/pipeline_stable_cascade_prior.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e2d21b212855111bc4680723a661409485b17e4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_cascade/__pycache__/pipeline_stable_cascade_prior.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py new file mode 100644 index 0000000000000000000000000000000000000000..ef40078bfbb907a5c3629f9768a4cddeec3e7054 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py @@ -0,0 +1,541 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable + +import torch +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from ...models import StableCascadeUNet +from ...schedulers import DDPMWuerstchenScheduler +from ...utils import is_torch_version, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput +from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline + + >>> prior_pipe = StableCascadePriorPipeline.from_pretrained( + ... "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16 + ... ).to("cuda") + >>> gen_pipe = StableCascadeDecoderPipeline.from_pretrain( + ... "stabilityai/stable-cascade", torch_dtype=torch.float16 + ... ).to("cuda") + + >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" + >>> prior_output = pipe(prompt) + >>> images = gen_pipe(prior_output.image_embeddings, prompt=prompt) + ``` +""" + + +class StableCascadeDecoderPipeline(DeprecatedPipelineMixin, DiffusionPipeline): + """ + Pipeline for generating images from the Stable Cascade model. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + tokenizer (`CLIPTokenizer`): + The CLIP tokenizer. + text_encoder (`CLIPTextModelWithProjection`): + The CLIP text encoder. + decoder ([`StableCascadeUNet`]): + The Stable Cascade decoder unet. + vqgan ([`PaellaVQModel`]): + The VQGAN model. + scheduler ([`DDPMWuerstchenScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + latent_dim_scale (float, `optional`, defaults to 10.67): + Multiplier to determine the VQ latent space size from the image embeddings. If the image embeddings are + height=24 and width=24, the VQ latent shape needs to be height=int(24*10.67)=256 and + width=int(24*10.67)=256 in order to match the training conditions. + """ + + _last_supported_version = "0.35.2" + + unet_name = "decoder" + text_encoder_name = "text_encoder" + model_cpu_offload_seq = "text_encoder->decoder->vqgan" + _callback_tensor_inputs = [ + "latents", + "prompt_embeds_pooled", + "negative_prompt_embeds", + "image_embeddings", + ] + + def __init__( + self, + decoder: StableCascadeUNet, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + scheduler: DDPMWuerstchenScheduler, + vqgan: PaellaVQModel, + latent_dim_scale: float = 10.67, + ) -> None: + super().__init__() + self.register_modules( + decoder=decoder, + tokenizer=tokenizer, + text_encoder=text_encoder, + scheduler=scheduler, + vqgan=vqgan, + ) + self.register_to_config(latent_dim_scale=latent_dim_scale) + + def prepare_latents( + self, batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, scheduler + ): + _, channels, height, width = image_embeddings.shape + latents_shape = ( + batch_size * num_images_per_prompt, + 4, + int(height * self.config.latent_dim_scale), + int(width * self.config.latent_dim_scale), + ) + + if latents is None: + latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def encode_prompt( + self, + device, + batch_size, + num_images_per_prompt, + do_classifier_free_guidance, + prompt=None, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_pooled: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_pooled: torch.Tensor | None = None, + ): + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + attention_mask = attention_mask[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask.to(device), output_hidden_states=True + ) + prompt_embeds = text_encoder_output.hidden_states[-1] + if prompt_embeds_pooled is None: + prompt_embeds_pooled = text_encoder_output.text_embeds.unsqueeze(1) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + prompt_embeds_pooled = prompt_embeds_pooled.to(dtype=self.text_encoder.dtype, device=device) + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + prompt_embeds_pooled = prompt_embeds_pooled.repeat_interleave(num_images_per_prompt, dim=0) + + if negative_prompt_embeds is None and do_classifier_free_guidance: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds_text_encoder_output = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=uncond_input.attention_mask.to(device), + output_hidden_states=True, + ) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.hidden_states[-1] + negative_prompt_embeds_pooled = negative_prompt_embeds_text_encoder_output.text_embeds.unsqueeze(1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + seq_len = negative_prompt_embeds_pooled.shape[1] + negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.to( + dtype=self.text_encoder.dtype, device=device + ) + negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + # done duplicates + + return prompt_embeds, prompt_embeds_pooled, negative_prompt_embeds, negative_prompt_embeds_pooled + + def check_inputs( + self, + prompt, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + def get_timestep_ratio_conditioning(self, t, alphas_cumprod): + s = torch.tensor([0.008]) + clamp_range = [0, 1] + min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2 + var = alphas_cumprod[t] + var = var.clamp(*clamp_range) + s, min_var = s.to(var.device), min_var.to(var.device) + ratio = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s + return ratio + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image_embeddings: torch.Tensor | list[torch.Tensor], + prompt: str | list[str] = None, + num_inference_steps: int = 10, + guidance_scale: float = 0.0, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_pooled: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_pooled: torch.Tensor | None = None, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image_embedding (`torch.Tensor` or `list[torch.Tensor]`): + Image Embeddings either extracted from an image or generated by a Prior Model. + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + num_inference_steps (`int`, *optional*, defaults to 12): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 0.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `decoder_guidance_scale` is defined as `w` of + equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by + setting `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are + closely linked to the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `decoder_guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_embeds_pooled (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_prompt_embeds_pooled (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt` + input argument. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is a list with the generated image + embeddings. + """ + + # 0. Define commonly used variables + device = self._execution_device + dtype = self.decoder.dtype + self._guidance_scale = guidance_scale + if is_torch_version("<", "2.2.0") and dtype == torch.bfloat16: + raise ValueError("`StableCascadeDecoderPipeline` requires torch>=2.2.0 when using `torch.bfloat16` dtype.") + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + if isinstance(image_embeddings, list): + image_embeddings = torch.cat(image_embeddings, dim=0) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Compute the effective number of images per prompt + # We must account for the fact that the image embeddings from the prior can be generated with num_images_per_prompt > 1 + # This results in a case where a single prompt is associated with multiple image embeddings + # Divide the number of image embeddings by the batch size to determine if this is the case. + num_images_per_prompt = num_images_per_prompt * (image_embeddings.shape[0] // batch_size) + + # 2. Encode caption + if prompt_embeds is None and negative_prompt_embeds is None: + _, prompt_embeds_pooled, _, negative_prompt_embeds_pooled = self.encode_prompt( + prompt=prompt, + device=device, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, + ) + + # The pooled embeds from the prior are pooled again before being passed to the decoder + prompt_embeds_pooled = ( + torch.cat([prompt_embeds_pooled, negative_prompt_embeds_pooled]) + if self.do_classifier_free_guidance + else prompt_embeds_pooled + ) + effnet = ( + torch.cat([image_embeddings, torch.zeros_like(image_embeddings)]) + if self.do_classifier_free_guidance + else image_embeddings + ) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latents + latents = self.prepare_latents( + batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler + ) + + if isinstance(self.scheduler, DDPMWuerstchenScheduler): + timesteps = timesteps[:-1] + else: + if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample: + self.scheduler.config.clip_sample = False # disample sample clipping + logger.warning(" set `clip_sample` to be False") + + # 6. Run denoising loop + if hasattr(self.scheduler, "betas"): + alphas = 1.0 - self.scheduler.betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + else: + alphas_cumprod = [] + + self._num_timesteps = len(timesteps) + for i, t in enumerate(self.progress_bar(timesteps)): + if not isinstance(self.scheduler, DDPMWuerstchenScheduler): + if len(alphas_cumprod) > 0: + timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod) + timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device) + else: + timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype) + else: + timestep_ratio = t.expand(latents.size(0)).to(dtype) + + # 7. Denoise latents + predicted_latents = self.decoder( + sample=torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents, + timestep_ratio=torch.cat([timestep_ratio] * 2) if self.do_classifier_free_guidance else timestep_ratio, + clip_text_pooled=prompt_embeds_pooled, + effnet=effnet, + return_dict=False, + )[0] + + # 8. Check for classifier free guidance and apply it + if self.do_classifier_free_guidance: + predicted_latents_text, predicted_latents_uncond = predicted_latents.chunk(2) + predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale) + + # 9. Renoise latents to next timestep + if not isinstance(self.scheduler, DDPMWuerstchenScheduler): + timestep_ratio = t + latents = self.scheduler.step( + model_output=predicted_latents, + timestep=timestep_ratio, + sample=latents, + generator=generator, + ).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type not in ["pt", "np", "pil", "latent"]: + raise ValueError( + f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}" + ) + + if not output_type == "latent": + # 10. Scale and decode the image latents with vq-vae + latents = self.vqgan.config.scale_factor * latents + images = self.vqgan.decode(latents).sample.clamp(0, 1) + if output_type == "np": + images = images.permute(0, 2, 3, 1).cpu().float().numpy() # float() as bfloat16-> numpy doesn't work + elif output_type == "pil": + images = images.permute(0, 2, 3, 1).cpu().float().numpy() # float() as bfloat16-> numpy doesn't work + images = self.numpy_to_pil(images) + else: + images = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return images + return ImagePipelineOutput(images) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py new file mode 100644 index 0000000000000000000000000000000000000000..50e6c02b6017f4ba2b85ba8333b2c4c2b4286631 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py @@ -0,0 +1,321 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable + +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...models import StableCascadeUNet +from ...schedulers import DDPMWuerstchenScheduler +from ...utils import is_torch_version, replace_example_docstring +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline +from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel +from .pipeline_stable_cascade import StableCascadeDecoderPipeline +from .pipeline_stable_cascade_prior import StableCascadePriorPipeline + + +TEXT2IMAGE_EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableCascadeCombinedPipeline + + >>> pipe = StableCascadeCombinedPipeline.from_pretrained( + ... "stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.enable_model_cpu_offload() + >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" + >>> images = pipe(prompt=prompt) + ``` +""" + + +class StableCascadeCombinedPipeline(DeprecatedPipelineMixin, DiffusionPipeline): + """ + Combined Pipeline for text-to-image generation using Stable Cascade. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + tokenizer (`CLIPTokenizer`): + The decoder tokenizer to be used for text inputs. + text_encoder (`CLIPTextModelWithProjection`): + The decoder text encoder to be used for text inputs. + decoder (`StableCascadeUNet`): + The decoder model to be used for decoder image generation pipeline. + scheduler (`DDPMWuerstchenScheduler`): + The scheduler to be used for decoder image generation pipeline. + vqgan (`PaellaVQModel`): + The VQGAN model to be used for decoder image generation pipeline. + prior_prior (`StableCascadeUNet`): + The prior model to be used for prior pipeline. + prior_text_encoder (`CLIPTextModelWithProjection`): + The prior text encoder to be used for text inputs. + prior_tokenizer (`CLIPTokenizer`): + The prior tokenizer to be used for text inputs. + prior_scheduler (`DDPMWuerstchenScheduler`): + The scheduler to be used for prior pipeline. + prior_feature_extractor ([`~transformers.CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `image_encoder`. + prior_image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + """ + + _last_supported_version = "0.35.2" + + _load_connected_pipes = True + _optional_components = ["prior_feature_extractor", "prior_image_encoder"] + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + decoder: StableCascadeUNet, + scheduler: DDPMWuerstchenScheduler, + vqgan: PaellaVQModel, + prior_prior: StableCascadeUNet, + prior_text_encoder: CLIPTextModelWithProjection, + prior_tokenizer: CLIPTokenizer, + prior_scheduler: DDPMWuerstchenScheduler, + prior_feature_extractor: CLIPImageProcessor | None = None, + prior_image_encoder: CLIPVisionModelWithProjection | None = None, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + decoder=decoder, + scheduler=scheduler, + vqgan=vqgan, + prior_text_encoder=prior_text_encoder, + prior_tokenizer=prior_tokenizer, + prior_prior=prior_prior, + prior_scheduler=prior_scheduler, + prior_feature_extractor=prior_feature_extractor, + prior_image_encoder=prior_image_encoder, + ) + self.prior_pipe = StableCascadePriorPipeline( + prior=prior_prior, + text_encoder=prior_text_encoder, + tokenizer=prior_tokenizer, + scheduler=prior_scheduler, + image_encoder=prior_image_encoder, + feature_extractor=prior_feature_extractor, + ) + self.decoder_pipe = StableCascadeDecoderPipeline( + text_encoder=text_encoder, + tokenizer=tokenizer, + decoder=decoder, + scheduler=scheduler, + vqgan=vqgan, + ) + + def enable_xformers_memory_efficient_attention(self, attention_op: Callable | None = None): + self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) + + def enable_model_cpu_offload(self, gpu_id: int | None = None, device: torch.device | str = None): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) + self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) + + def enable_sequential_cpu_offload(self, gpu_id: int | None = None, device: torch.device | str = None): + r""" + Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗 + Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a + GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis. + Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower. + """ + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + + def progress_bar(self, iterable=None, total=None): + self.prior_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.progress_bar(iterable=iterable, total=total) + + def set_progress_bar_config(self, **kwargs): + self.prior_pipe.set_progress_bar_config(**kwargs) + self.decoder_pipe.set_progress_bar_config(**kwargs) + + @torch.no_grad() + @replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + images: torch.Tensor | PIL.Image.Image | list[torch.Tensor] | list[PIL.Image.Image] = None, + height: int = 512, + width: int = 512, + prior_num_inference_steps: int = 60, + prior_guidance_scale: float = 4.0, + num_inference_steps: int = 12, + decoder_guidance_scale: float = 0.0, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_pooled: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_pooled: torch.Tensor | None = None, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + prior_callback_on_step_end: Callable[[int, int], None] | None = None, + prior_callback_on_step_end_tensor_inputs: list[str] = ["latents"], + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation for the prior and decoder. + images (`torch.Tensor`, `PIL.Image.Image`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, *optional*): + The images to guide the image generation for the prior. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, text embeddings will be generated from `prompt` input argument. + prompt_embeds_pooled (`torch.Tensor`, *optional*): + Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* + prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` + input argument. + negative_prompt_embeds_pooled (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* + prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` + input argument. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `prior_guidance_scale` is defined as `w` of + equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by + setting `prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are + closely linked to the text `prompt`, usually at the expense of lower image quality. + prior_num_inference_steps (`int | dict[float, int]`, *optional*, defaults to 60): + The number of prior denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. For more specific timestep spacing, you can pass customized + `prior_timesteps` + num_inference_steps (`int`, *optional*, defaults to 12): + The number of decoder denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. For more specific timestep spacing, you can pass customized + `timesteps` + decoder_guidance_scale (`float`, *optional*, defaults to 0.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + prior_callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `prior_callback_on_step_end(self: DiffusionPipeline, step: int, timestep: + int, callback_kwargs: Dict)`. + prior_callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the + list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in + the `._callback_tensor_inputs` attribute of your pipeline class. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + dtype = self.decoder_pipe.decoder.dtype + if is_torch_version("<", "2.2.0") and dtype == torch.bfloat16: + raise ValueError( + "`StableCascadeCombinedPipeline` requires torch>=2.2.0 when using `torch.bfloat16` dtype." + ) + + prior_outputs = self.prior_pipe( + prompt=prompt if prompt_embeds is None else None, + images=images, + height=height, + width=width, + num_inference_steps=prior_num_inference_steps, + guidance_scale=prior_guidance_scale, + negative_prompt=negative_prompt if negative_prompt_embeds is None else None, + prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + output_type="pt", + return_dict=True, + callback_on_step_end=prior_callback_on_step_end, + callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs, + ) + image_embeddings = prior_outputs.image_embeddings + prompt_embeds = prior_outputs.get("prompt_embeds", None) + prompt_embeds_pooled = prior_outputs.get("prompt_embeds_pooled", None) + negative_prompt_embeds = prior_outputs.get("negative_prompt_embeds", None) + negative_prompt_embeds_pooled = prior_outputs.get("negative_prompt_embeds_pooled", None) + + outputs = self.decoder_pipe( + image_embeddings=image_embeddings, + prompt=prompt if prompt_embeds is None else None, + num_inference_steps=num_inference_steps, + guidance_scale=decoder_guidance_scale, + negative_prompt=negative_prompt if negative_prompt_embeds is None else None, + prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, + generator=generator, + output_type=output_type, + return_dict=return_dict, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + return outputs diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py new file mode 100644 index 0000000000000000000000000000000000000000..0c5ea9ed61b46b0597a960f4b2ec5474846760b2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py @@ -0,0 +1,652 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from math import ceil +from typing import Callable + +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...models import StableCascadeUNet +from ...schedulers import DDPMWuerstchenScheduler +from ...utils import BaseOutput, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:] + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableCascadePriorPipeline + + >>> prior_pipe = StableCascadePriorPipeline.from_pretrained( + ... "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16 + ... ).to("cuda") + + >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" + >>> prior_output = pipe(prompt) + ``` +""" + + +@dataclass +class StableCascadePriorPipelineOutput(BaseOutput): + """ + Output class for WuerstchenPriorPipeline. + + Args: + image_embeddings (`torch.Tensor` or `np.ndarray`) + Prior image embeddings for text prompt + prompt_embeds (`torch.Tensor`): + Text embeddings for the prompt. + negative_prompt_embeds (`torch.Tensor`): + Text embeddings for the negative prompt. + """ + + image_embeddings: torch.Tensor | np.ndarray + prompt_embeds: torch.Tensor | np.ndarray + prompt_embeds_pooled: torch.Tensor | np.ndarray + negative_prompt_embeds: torch.Tensor | np.ndarray + negative_prompt_embeds_pooled: torch.Tensor | np.ndarray + + +class StableCascadePriorPipeline(DeprecatedPipelineMixin, DiffusionPipeline): + """ + Pipeline for generating image prior for Stable Cascade. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + prior ([`StableCascadeUNet`]): + The Stable Cascade prior to approximate the image embedding from the text and/or image embedding. + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder + ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)). + feature_extractor ([`~transformers.CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `image_encoder`. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + scheduler ([`DDPMWuerstchenScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + resolution_multiple ('float', *optional*, defaults to 42.67): + Default resolution for multiple images generated. + """ + + _last_supported_version = "0.35.2" + + unet_name = "prior" + text_encoder_name = "text_encoder" + model_cpu_offload_seq = "image_encoder->text_encoder->prior" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "text_encoder_hidden_states", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + prior: StableCascadeUNet, + scheduler: DDPMWuerstchenScheduler, + resolution_multiple: float = 42.67, + feature_extractor: CLIPImageProcessor | None = None, + image_encoder: CLIPVisionModelWithProjection | None = None, + ) -> None: + super().__init__() + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + prior=prior, + scheduler=scheduler, + ) + self.register_to_config(resolution_multiple=resolution_multiple) + + def prepare_latents( + self, batch_size, height, width, num_images_per_prompt, dtype, device, generator, latents, scheduler + ): + latent_shape = ( + num_images_per_prompt * batch_size, + self.prior.config.in_channels, + ceil(height / self.config.resolution_multiple), + ceil(width / self.config.resolution_multiple), + ) + + if latents is None: + latents = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != latent_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latent_shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def encode_prompt( + self, + device, + batch_size, + num_images_per_prompt, + do_classifier_free_guidance, + prompt=None, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_pooled: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_pooled: torch.Tensor | None = None, + ): + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + attention_mask = attention_mask[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask.to(device), output_hidden_states=True + ) + prompt_embeds = text_encoder_output.hidden_states[-1] + if prompt_embeds_pooled is None: + prompt_embeds_pooled = text_encoder_output.text_embeds.unsqueeze(1) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + prompt_embeds_pooled = prompt_embeds_pooled.to(dtype=self.text_encoder.dtype, device=device) + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + prompt_embeds_pooled = prompt_embeds_pooled.repeat_interleave(num_images_per_prompt, dim=0) + + if negative_prompt_embeds is None and do_classifier_free_guidance: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds_text_encoder_output = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=uncond_input.attention_mask.to(device), + output_hidden_states=True, + ) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.hidden_states[-1] + negative_prompt_embeds_pooled = negative_prompt_embeds_text_encoder_output.text_embeds.unsqueeze(1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + seq_len = negative_prompt_embeds_pooled.shape[1] + negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.to( + dtype=self.text_encoder.dtype, device=device + ) + negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds_pooled = negative_prompt_embeds_pooled.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + # done duplicates + + return prompt_embeds, prompt_embeds_pooled, negative_prompt_embeds, negative_prompt_embeds_pooled + + def encode_image(self, images, device, dtype, batch_size, num_images_per_prompt): + image_embeds = [] + for image in images: + image = self.feature_extractor(image, return_tensors="pt").pixel_values + image = image.to(device=device, dtype=dtype) + image_embed = self.image_encoder(image).image_embeds.unsqueeze(1) + image_embeds.append(image_embed) + image_embeds = torch.cat(image_embeds, dim=1) + + image_embeds = image_embeds.repeat(batch_size * num_images_per_prompt, 1, 1) + negative_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, negative_image_embeds + + def check_inputs( + self, + prompt, + images=None, + image_embeds=None, + negative_prompt=None, + prompt_embeds=None, + prompt_embeds_pooled=None, + negative_prompt_embeds=None, + negative_prompt_embeds_pooled=None, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and prompt_embeds_pooled is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_pooled` must also be provided. Make sure to generate `prompt_embeds_pooled` from the same text encoder that was used to generate `prompt_embeds`" + ) + + if negative_prompt_embeds is not None and negative_prompt_embeds_pooled is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_pooled` must also be provided. Make sure to generate `prompt_embeds_pooled` from the same text encoder that was used to generate `prompt_embeds`" + ) + + if prompt_embeds_pooled is not None and negative_prompt_embeds_pooled is not None: + if prompt_embeds_pooled.shape != negative_prompt_embeds_pooled.shape: + raise ValueError( + "`prompt_embeds_pooled` and `negative_prompt_embeds_pooled` must have the same shape when passed" + f"directly, but got: `prompt_embeds_pooled` {prompt_embeds_pooled.shape} !=" + f"`negative_prompt_embeds_pooled` {negative_prompt_embeds_pooled.shape}." + ) + + if image_embeds is not None and images is not None: + raise ValueError( + f"Cannot forward both `images`: {images} and `image_embeds`: {image_embeds}. Please make sure to" + " only forward one of the two." + ) + + if images: + for i, image in enumerate(images): + if not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise TypeError( + f"'images' must contain images of type 'torch.Tensor' or 'PIL.Image.Image, but got" + f"{type(image)} for image number {i}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + def get_timestep_ratio_conditioning(self, t, alphas_cumprod): + s = torch.tensor([0.008]) + clamp_range = [0, 1] + min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2 + var = alphas_cumprod[t] + var = var.clamp(*clamp_range) + s, min_var = s.to(var.device), min_var.to(var.device) + ratio = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s + return ratio + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + images: torch.Tensor | PIL.Image.Image | list[torch.Tensor] | list[PIL.Image.Image] = None, + height: int = 1024, + width: int = 1024, + num_inference_steps: int = 20, + timesteps: list[float] = None, + guidance_scale: float = 4.0, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_embeds_pooled: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds_pooled: torch.Tensor | None = None, + image_embeds: torch.Tensor | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pt", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 60): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 8.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `decoder_guidance_scale` is defined as `w` of + equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by + setting `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are + closely linked to the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `decoder_guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_embeds_pooled (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_prompt_embeds_pooled (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt` + input argument. + image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings. Can be used to easily tweak image inputs, *e.g.* prompt weighting. If + not provided, image embeddings will be generated from `image` input argument if existing. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`StableCascadePriorPipelineOutput`] or `tuple` [`StableCascadePriorPipelineOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated image + embeddings. + """ + + # 0. Define commonly used variables + device = self._execution_device + dtype = next(self.prior.parameters()).dtype + self._guidance_scale = guidance_scale + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + images=images, + image_embeds=image_embeds, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + # 2. Encode caption + images + ( + prompt_embeds, + prompt_embeds_pooled, + negative_prompt_embeds, + negative_prompt_embeds_pooled, + ) = self.encode_prompt( + prompt=prompt, + device=device, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, + ) + + if images is not None: + image_embeds_pooled, uncond_image_embeds_pooled = self.encode_image( + images=images, + device=device, + dtype=dtype, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + ) + elif image_embeds is not None: + image_embeds_pooled = image_embeds.repeat(batch_size * num_images_per_prompt, 1, 1) + uncond_image_embeds_pooled = torch.zeros_like(image_embeds_pooled) + else: + image_embeds_pooled = torch.zeros( + batch_size * num_images_per_prompt, + 1, + self.prior.config.clip_image_in_channels, + device=device, + dtype=dtype, + ) + uncond_image_embeds_pooled = torch.zeros( + batch_size * num_images_per_prompt, + 1, + self.prior.config.clip_image_in_channels, + device=device, + dtype=dtype, + ) + + if self.do_classifier_free_guidance: + image_embeds = torch.cat([image_embeds_pooled, uncond_image_embeds_pooled], dim=0) + else: + image_embeds = image_embeds_pooled + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_encoder_hidden_states = ( + torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds + ) + text_encoder_pooled = ( + torch.cat([prompt_embeds_pooled, negative_prompt_embeds_pooled]) + if negative_prompt_embeds is not None + else prompt_embeds_pooled + ) + + # 4. Prepare and set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latents + latents = self.prepare_latents( + batch_size, height, width, num_images_per_prompt, dtype, device, generator, latents, self.scheduler + ) + + if isinstance(self.scheduler, DDPMWuerstchenScheduler): + timesteps = timesteps[:-1] + else: + if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample: + self.scheduler.config.clip_sample = False # disample sample clipping + logger.warning(" set `clip_sample` to be False") + # 6. Run denoising loop + if hasattr(self.scheduler, "betas"): + alphas = 1.0 - self.scheduler.betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + else: + alphas_cumprod = [] + + self._num_timesteps = len(timesteps) + for i, t in enumerate(self.progress_bar(timesteps)): + if not isinstance(self.scheduler, DDPMWuerstchenScheduler): + if len(alphas_cumprod) > 0: + timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod) + timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device) + else: + timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype) + else: + timestep_ratio = t.expand(latents.size(0)).to(dtype) + # 7. Denoise image embeddings + predicted_image_embedding = self.prior( + sample=torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents, + timestep_ratio=torch.cat([timestep_ratio] * 2) if self.do_classifier_free_guidance else timestep_ratio, + clip_text_pooled=text_encoder_pooled, + clip_text=text_encoder_hidden_states, + clip_img=image_embeds, + return_dict=False, + )[0] + + # 8. Check for classifier free guidance and apply it + if self.do_classifier_free_guidance: + predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) + predicted_image_embedding = torch.lerp( + predicted_image_embedding_uncond, predicted_image_embedding_text, self.guidance_scale + ) + + # 9. Renoise latents to next timestep + if not isinstance(self.scheduler, DDPMWuerstchenScheduler): + timestep_ratio = t + latents = self.scheduler.step( + model_output=predicted_image_embedding, timestep=timestep_ratio, sample=latents, generator=generator + ).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if XLA_AVAILABLE: + xm.mark_step() + + # Offload all models + self.maybe_free_model_hooks() + + if output_type == "np": + latents = latents.cpu().float().numpy() # float() as bfloat16-> numpy doesn't work + prompt_embeds = prompt_embeds.cpu().float().numpy() # float() as bfloat16-> numpy doesn't work + negative_prompt_embeds = ( + negative_prompt_embeds.cpu().float().numpy() if negative_prompt_embeds is not None else None + ) # float() as bfloat16-> numpy doesn't work + + if not return_dict: + return ( + latents, + prompt_embeds, + prompt_embeds_pooled, + negative_prompt_embeds, + negative_prompt_embeds_pooled, + ) + + return StableCascadePriorPipelineOutput( + image_embeddings=latents, + prompt_embeds=prompt_embeds, + prompt_embeds_pooled=prompt_embeds_pooled, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_embeds_pooled=negative_prompt_embeds_pooled, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c2eebf586ef816e1590a00402aa330979e17eb82 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/__init__.py @@ -0,0 +1,193 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_onnx_available, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["StableDiffusionPipelineOutput"]} + +if is_transformers_available() and is_flax_available(): + _import_structure["pipeline_output"].extend(["FlaxStableDiffusionPipelineOutput"]) +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["clip_image_project_model"] = ["CLIPImageProjection"] + _import_structure["pipeline_stable_diffusion"] = ["StableDiffusionPipeline"] + _import_structure["pipeline_stable_diffusion_img2img"] = ["StableDiffusionImg2ImgPipeline"] + _import_structure["pipeline_stable_diffusion_inpaint"] = ["StableDiffusionInpaintPipeline"] + _import_structure["pipeline_stable_diffusion_instruct_pix2pix"] = ["StableDiffusionInstructPix2PixPipeline"] + _import_structure["pipeline_stable_diffusion_latent_upscale"] = ["StableDiffusionLatentUpscalePipeline"] + _import_structure["pipeline_stable_diffusion_upscale"] = ["StableDiffusionUpscalePipeline"] + _import_structure["pipeline_stable_unclip"] = ["StableUnCLIPPipeline"] + _import_structure["pipeline_stable_unclip_img2img"] = ["StableUnCLIPImg2ImgPipeline"] + _import_structure["safety_checker"] = ["StableDiffusionSafetyChecker"] + _import_structure["stable_unclip_image_normalizer"] = ["StableUnCLIPImageNormalizer"] +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + StableDiffusionImageVariationPipeline, + ) + + _dummy_objects.update({"StableDiffusionImageVariationPipeline": StableDiffusionImageVariationPipeline}) +else: + _import_structure["pipeline_stable_diffusion_image_variation"] = ["StableDiffusionImageVariationPipeline"] +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.26.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + StableDiffusionDepth2ImgPipeline, + ) + + _dummy_objects.update( + { + "StableDiffusionDepth2ImgPipeline": StableDiffusionDepth2ImgPipeline, + } + ) +else: + _import_structure["pipeline_stable_diffusion_depth2img"] = ["StableDiffusionDepth2ImgPipeline"] + +try: + if not (is_transformers_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_onnx_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_onnx_objects)) +else: + _import_structure["pipeline_onnx_stable_diffusion"] = [ + "OnnxStableDiffusionPipeline", + "StableDiffusionOnnxPipeline", + ] + _import_structure["pipeline_onnx_stable_diffusion_img2img"] = ["OnnxStableDiffusionImg2ImgPipeline"] + _import_structure["pipeline_onnx_stable_diffusion_inpaint"] = ["OnnxStableDiffusionInpaintPipeline"] + _import_structure["pipeline_onnx_stable_diffusion_inpaint_legacy"] = ["OnnxStableDiffusionInpaintPipelineLegacy"] + _import_structure["pipeline_onnx_stable_diffusion_upscale"] = ["OnnxStableDiffusionUpscalePipeline"] + +if is_transformers_available() and is_flax_available(): + from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState + + _additional_imports.update({"PNDMSchedulerState": PNDMSchedulerState}) + _import_structure["pipeline_flax_stable_diffusion"] = ["FlaxStableDiffusionPipeline"] + _import_structure["pipeline_flax_stable_diffusion_img2img"] = ["FlaxStableDiffusionImg2ImgPipeline"] + _import_structure["pipeline_flax_stable_diffusion_inpaint"] = ["FlaxStableDiffusionInpaintPipeline"] + _import_structure["safety_checker_flax"] = ["FlaxStableDiffusionSafetyChecker"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + + else: + from .clip_image_project_model import CLIPImageProjection + from .pipeline_stable_diffusion import ( + StableDiffusionPipeline, + StableDiffusionPipelineOutput, + ) + from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline + from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline + from .pipeline_stable_diffusion_instruct_pix2pix import ( + StableDiffusionInstructPix2PixPipeline, + ) + from .pipeline_stable_diffusion_latent_upscale import ( + StableDiffusionLatentUpscalePipeline, + ) + from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline + from .pipeline_stable_unclip import StableUnCLIPPipeline + from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline + from .safety_checker import StableDiffusionSafetyChecker + from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer + + try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + StableDiffusionImageVariationPipeline, + ) + else: + from .pipeline_stable_diffusion_image_variation import ( + StableDiffusionImageVariationPipeline, + ) + + try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.26.0")): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import StableDiffusionDepth2ImgPipeline + else: + from .pipeline_stable_diffusion_depth2img import ( + StableDiffusionDepth2ImgPipeline, + ) + + try: + if not (is_transformers_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_onnx_objects import * + else: + from .pipeline_onnx_stable_diffusion import ( + OnnxStableDiffusionPipeline, + StableDiffusionOnnxPipeline, + ) + from .pipeline_onnx_stable_diffusion_img2img import ( + OnnxStableDiffusionImg2ImgPipeline, + ) + from .pipeline_onnx_stable_diffusion_inpaint import ( + OnnxStableDiffusionInpaintPipeline, + ) + from .pipeline_onnx_stable_diffusion_upscale import ( + OnnxStableDiffusionUpscalePipeline, + ) + + try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_flax_objects import * + else: + from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline + from .pipeline_flax_stable_diffusion_img2img import ( + FlaxStableDiffusionImg2ImgPipeline, + ) + from .pipeline_flax_stable_diffusion_inpaint import ( + FlaxStableDiffusionInpaintPipeline, + ) + from .pipeline_output import FlaxStableDiffusionPipelineOutput + from .safety_checker_flax import FlaxStableDiffusionSafetyChecker + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/clip_image_project_model.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/clip_image_project_model.py new file mode 100644 index 0000000000000000000000000000000000000000..30dd90242d07d249e62e976c148a9d2cde02271e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/clip_image_project_model.py @@ -0,0 +1,29 @@ +# Copyright 2025 The GLIGEN Authors and HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin + + +class CLIPImageProjection(ModelMixin, ConfigMixin): + @register_to_config + def __init__(self, hidden_size: int = 768): + super().__init__() + self.hidden_size = hidden_size + self.project = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + + def forward(self, x): + return self.project(x) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..8fa4bd5941d2d86fdab89557d6494cbcd9d11e68 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -0,0 +1,1876 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Conversion script for the Stable Diffusion checkpoints.""" + +import re +from contextlib import nullcontext +from io import BytesIO + +import requests +import torch +import yaml +from transformers import ( + AutoFeatureExtractor, + BertTokenizerFast, + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) + +from ...models import ( + AutoencoderKL, + ControlNetModel, + PriorTransformer, + UNet2DConditionModel, +) +from ...schedulers import ( + DDIMScheduler, + DDPMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UnCLIPScheduler, +) +from ...utils import is_accelerate_available, logging +from ...utils.constants import DIFFUSERS_REQUEST_TIMEOUT +from ...utils.torch_utils import get_device +from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel +from ..paint_by_example import PaintByExampleImageEncoder +from ..pipeline_utils import DiffusionPipeline +from .safety_checker import StableDiffusionSafetyChecker +from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer + + +if is_accelerate_available(): + from accelerate import init_empty_weights + from accelerate.utils import set_module_tensor_to_device + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) + shape = old_checkpoint[path["old"]].shape + if is_attn_weight and len(shape) == 3: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif is_attn_weight and len(shape) == 4: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + if controlnet: + unet_params = original_config["model"]["params"]["control_stage_config"]["params"] + else: + if ( + "unet_config" in original_config["model"]["params"] + and original_config["model"]["params"]["unet_config"] is not None + ): + unet_params = original_config["model"]["params"]["unet_config"]["params"] + else: + unet_params = original_config["model"]["params"]["network_config"]["params"] + + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + + block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + if unet_params["transformer_depth"] is not None: + transformer_layers_per_block = ( + unet_params["transformer_depth"] + if isinstance(unet_params["transformer_depth"], int) + else list(unet_params["transformer_depth"]) + ) + else: + transformer_layers_per_block = 1 + + vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) + + head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"] + head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])] + + class_embed_type = None + addition_embed_type = None + addition_time_embed_dim = None + projection_class_embeddings_input_dim = None + context_dim = None + + if unet_params["context_dim"] is not None: + context_dim = ( + unet_params["context_dim"] + if isinstance(unet_params["context_dim"], int) + else unet_params["context_dim"][0] + ) + + if "num_classes" in unet_params: + if unet_params["num_classes"] == "sequential": + if context_dim in [2048, 1280]: + # SDXL + addition_embed_type = "text_time" + addition_time_embed_dim = 256 + else: + class_embed_type = "projection" + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params["adm_in_channels"] + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params["in_channels"], + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params["num_res_blocks"], + "cross_attention_dim": context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "addition_embed_type": addition_embed_type, + "addition_time_embed_dim": addition_time_embed_dim, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "transformer_layers_per_block": transformer_layers_per_block, + } + + if "disable_self_attentions" in unet_params: + config["only_cross_attention"] = unet_params["disable_self_attentions"] + + if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int): + config["num_class_embeds"] = unet_params["num_classes"] + + if controlnet: + config["conditioning_channels"] = unet_params["hint_channels"] + else: + config["out_channels"] = unet_params["out_channels"] + config["up_block_types"] = tuple(up_block_types) + + return config + + +def create_vae_diffusers_config(original_config, image_size: int): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"] + + block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] + down_block_types = [ + "DownEncoderBlock2D" if image_size // 2**i not in vae_params["attn_resolutions"] else "AttnDownEncoderBlock2D" + for i, _ in enumerate(block_out_channels) + ] + up_block_types = [ + "UpDecoderBlock2D" if image_size // 2**i not in vae_params["attn_resolutions"] else "AttnUpDecoderBlock2D" + for i, _ in enumerate(block_out_channels) + ][::-1] + + config = { + "sample_size": image_size, + "in_channels": vae_params["in_channels"], + "out_channels": vae_params["out_ch"], + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params["z_channels"], + "layers_per_block": vae_params["num_res_blocks"], + } + return config + + +def create_diffusers_schedular(original_config): + schedular = DDIMScheduler( + num_train_timesteps=original_config["model"]["params"]["timesteps"], + beta_start=original_config["model"]["params"]["linear_start"], + beta_end=original_config["model"]["params"]["linear_end"], + beta_schedule="scaled_linear", + ) + return schedular + + +def create_ldm_bert_config(original_config): + bert_params = original_config["model"]["params"]["cond_stage_config"]["params"] + config = LDMBertConfig( + d_model=bert_params.n_embed, + encoder_layers=bert_params.n_layer, + encoder_ffn_dim=bert_params.n_embed * 4, + ) + return config + + +def convert_ldm_unet_checkpoint( + checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False +): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + if skip_extract_state_dict: + unet_state_dict = checkpoint + else: + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + if controlnet: + unet_key = "control_model." + else: + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.") + logger.warning( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + logger.warning( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + if config["class_embed_type"] is None: + # No parameters to port + ... + elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + if config["addition_embed_type"] == "text_time": + new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + + # Relevant to StableDiffusionUpscalePipeline + if "num_class_embeds" in config: + if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict): + new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + if not controlnet: + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in sorted(output_block_list.items())} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + if controlnet: + # conditioning embedding + + orig_index = 0 + + new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + orig_index += 2 + + diffusers_index = 0 + + while diffusers_index < 6: + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + diffusers_index += 1 + orig_index += 2 + + new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + # down blocks + for i in range(num_input_blocks): + new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") + new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") + + # mid block + new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") + new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") + + return new_checkpoint + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + keys = list(checkpoint.keys()) + vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else "" + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def convert_ldm_bert_checkpoint(checkpoint, config): + def _copy_attn_layer(hf_attn_layer, pt_attn_layer): + hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight + hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight + hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight + + hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight + hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias + + def _copy_linear(hf_linear, pt_linear): + hf_linear.weight = pt_linear.weight + hf_linear.bias = pt_linear.bias + + def _copy_layer(hf_layer, pt_layer): + # copy layer norms + _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) + _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) + + # copy attn + _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) + + # copy MLP + pt_mlp = pt_layer[1][1] + _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) + _copy_linear(hf_layer.fc2, pt_mlp.net[2]) + + def _copy_layers(hf_layers, pt_layers): + for i, hf_layer in enumerate(hf_layers): + if i != 0: + i += i + pt_layer = pt_layers[i : i + 2] + _copy_layer(hf_layer, pt_layer) + + hf_model = LDMBertModel(config).eval() + + # copy embeds + hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight + hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight + + # copy layer norm + _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) + + # copy hidden layers + _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) + + _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) + + return hf_model + + +def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None): + if text_encoder is None: + config_name = "openai/clip-vit-large-patch14" + try: + config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'." + ) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + text_model = CLIPTextModel(config) + else: + text_model = text_encoder + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"] + + for key in keys: + for prefix in remove_prefixes: + if key.startswith(prefix): + text_model_dict[key[len(prefix + ".") :]] = checkpoint[key] + + if is_accelerate_available(): + for param_name, param in text_model_dict.items(): + set_module_tensor_to_device(text_model, param_name, "cpu", value=param) + else: + if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)): + text_model_dict.pop("text_model.embeddings.position_ids", None) + + text_model.load_state_dict(text_model_dict) + + return text_model + + +textenc_conversion_lst = [ + ("positional_embedding", "text_model.embeddings.position_embedding.weight"), + ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"), + ("ln_final.weight", "text_model.final_layer_norm.weight"), + ("ln_final.bias", "text_model.final_layer_norm.bias"), + ("text_projection", "text_projection.weight"), +] +textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} + +textenc_transformer_conversion_lst = [ + # (stable-diffusion, HF Diffusers) + ("resblocks.", "text_model.encoder.layers."), + ("ln_1", "layer_norm1"), + ("ln_2", "layer_norm2"), + (".c_fc.", ".fc1."), + (".c_proj.", ".fc2."), + (".attn", ".self_attn"), + ("ln_final.", "transformer.text_model.final_layer_norm."), + ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), + ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), +] +protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} +textenc_pattern = re.compile("|".join(protected.keys())) + + +def convert_paint_by_example_checkpoint(checkpoint, local_files_only=False): + config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only) + model = PaintByExampleImageEncoder(config) + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + + # load clip vision + model.model.load_state_dict(text_model_dict) + + # load mapper + keys_mapper = { + k[len("cond_stage_model.mapper.res") :]: v + for k, v in checkpoint.items() + if k.startswith("cond_stage_model.mapper") + } + + MAPPING = { + "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"], + "attn.c_proj": ["attn1.to_out.0"], + "ln_1": ["norm1"], + "ln_2": ["norm3"], + "mlp.c_fc": ["ff.net.0.proj"], + "mlp.c_proj": ["ff.net.2"], + } + + mapped_weights = {} + for key, value in keys_mapper.items(): + prefix = key[: len("blocks.i")] + suffix = key.split(prefix)[-1].split(".")[-1] + name = key.split(prefix)[-1].split(suffix)[0][1:-1] + mapped_names = MAPPING[name] + + num_splits = len(mapped_names) + for i, mapped_name in enumerate(mapped_names): + new_name = ".".join([prefix, mapped_name, suffix]) + shape = value.shape[0] // num_splits + mapped_weights[new_name] = value[i * shape : (i + 1) * shape] + + model.mapper.load_state_dict(mapped_weights) + + # load final layer norm + model.final_layer_norm.load_state_dict( + { + "bias": checkpoint["cond_stage_model.final_ln.bias"], + "weight": checkpoint["cond_stage_model.final_ln.weight"], + } + ) + + # load final proj + model.proj_out.load_state_dict( + { + "bias": checkpoint["proj_out.bias"], + "weight": checkpoint["proj_out.weight"], + } + ) + + # load uncond vector + model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"]) + return model + + +def convert_open_clip_checkpoint( + checkpoint, + config_name, + prefix="cond_stage_model.model.", + has_projection=False, + local_files_only=False, + **config_kwargs, +): + # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") + # text_model = CLIPTextModelWithProjection.from_pretrained( + # "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280 + # ) + try: + config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: '{config_name}'." + ) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config) + + keys = list(checkpoint.keys()) + + keys_to_ignore = [] + if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23: + # make sure to remove all keys > 22 + keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")] + keys_to_ignore += ["cond_stage_model.model.text_projection"] + + text_model_dict = {} + + if prefix + "text_projection" in checkpoint: + d_model = int(checkpoint[prefix + "text_projection"].shape[0]) + else: + d_model = 1024 + + text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") + + for key in keys: + if key in keys_to_ignore: + continue + if key[len(prefix) :] in textenc_conversion_map: + if key.endswith("text_projection"): + value = checkpoint[key].T.contiguous() + else: + value = checkpoint[key] + + text_model_dict[textenc_conversion_map[key[len(prefix) :]]] = value + + if key.startswith(prefix + "transformer."): + new_key = key[len(prefix + "transformer.") :] + if new_key.endswith(".in_proj_weight"): + new_key = new_key[: -len(".in_proj_weight")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] + text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] + text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] + elif new_key.endswith(".in_proj_bias"): + new_key = new_key[: -len(".in_proj_bias")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] + text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] + text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] + else: + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + + text_model_dict[new_key] = checkpoint[key] + + if is_accelerate_available(): + for param_name, param in text_model_dict.items(): + set_module_tensor_to_device(text_model, param_name, "cpu", value=param) + else: + if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)): + text_model_dict.pop("text_model.embeddings.position_ids", None) + + text_model.load_state_dict(text_model_dict) + + return text_model + + +def stable_unclip_image_encoder(original_config, local_files_only=False): + """ + Returns the image processor and clip image encoder for the img2img unclip pipeline. + + We currently know of two types of stable unclip models which separately use the clip and the openclip image + encoders. + """ + + image_embedder_config = original_config["model"]["params"]["embedder_config"] + + sd_clip_image_embedder_class = image_embedder_config["target"] + sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] + + if sd_clip_image_embedder_class == "ClipImageEmbedder": + clip_model_name = image_embedder_config.params.model + + if clip_model_name == "ViT-L/14": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "openai/clip-vit-large-patch14", local_files_only=local_files_only + ) + else: + raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}") + + elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", local_files_only=local_files_only + ) + else: + raise NotImplementedError( + f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" + ) + + return feature_extractor, image_encoder + + +def stable_unclip_image_noising_components( + original_config, clip_stats_path: str | None = None, device: str | None = None +): + """ + Returns the noising components for the img2img and txt2img unclip pipelines. + + Converts the stability noise augmentor into + 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats + 2. a `DDPMScheduler` for holding the noise schedule + + If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. + """ + noise_aug_config = original_config["model"]["params"]["noise_aug_config"] + noise_aug_class = noise_aug_config["target"] + noise_aug_class = noise_aug_class.split(".")[-1] + + if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": + noise_aug_config = noise_aug_config.params + embedding_dim = noise_aug_config.timestep_dim + max_noise_level = noise_aug_config.noise_schedule_config.timesteps + beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule + + image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) + image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) + + if "clip_stats_path" in noise_aug_config: + if clip_stats_path is None: + raise ValueError("This stable unclip config requires a `clip_stats_path`") + + clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) + clip_mean = clip_mean[None, :] + clip_std = clip_std[None, :] + + clip_stats_state_dict = { + "mean": clip_mean, + "std": clip_std, + } + + image_normalizer.load_state_dict(clip_stats_state_dict) + else: + raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") + + return image_normalizer, image_noising_scheduler + + +def convert_controlnet_checkpoint( + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=None, + cross_attention_dim=None, +): + ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) + ctrlnet_config["upcast_attention"] = upcast_attention + + ctrlnet_config.pop("sample_size") + + if use_linear_projection is not None: + ctrlnet_config["use_linear_projection"] = use_linear_projection + + if cross_attention_dim is not None: + ctrlnet_config["cross_attention_dim"] = cross_attention_dim + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + controlnet = ControlNetModel(**ctrlnet_config) + + # Some controlnet ckpt files are distributed independently from the rest of the + # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ + if "time_embed.0.weight" in checkpoint: + skip_extract_state_dict = True + else: + skip_extract_state_dict = False + + converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, + ctrlnet_config, + path=checkpoint_path, + extract_ema=extract_ema, + controlnet=True, + skip_extract_state_dict=skip_extract_state_dict, + ) + + if is_accelerate_available(): + for param_name, param in converted_ctrl_checkpoint.items(): + set_module_tensor_to_device(controlnet, param_name, "cpu", value=param) + else: + controlnet.load_state_dict(converted_ctrl_checkpoint) + + return controlnet + + +def download_from_original_stable_diffusion_ckpt( + checkpoint_path_or_dict: str | dict[str, torch.Tensor], + original_config_file: str = None, + image_size: int | None = None, + prediction_type: str = None, + model_type: str = None, + extract_ema: bool = False, + scheduler_type: str = "pndm", + num_in_channels: int | None = None, + upcast_attention: bool | None = None, + device: str = None, + from_safetensors: bool = False, + stable_unclip: str | None = None, + stable_unclip_prior: str | None = None, + clip_stats_path: str | None = None, + controlnet: bool | None = None, + adapter: bool | None = None, + load_safety_checker: bool = True, + safety_checker: StableDiffusionSafetyChecker | None = None, + feature_extractor: AutoFeatureExtractor | None = None, + pipeline_class: DiffusionPipeline = None, + local_files_only=False, + vae_path=None, + vae=None, + text_encoder=None, + text_encoder_2=None, + tokenizer=None, + tokenizer_2=None, + config_files=None, +) -> DiffusionPipeline: + """ + Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` + config file. + + Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the + global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is + recommended that you override the default values and/or supply an `original_config_file` wherever possible. + + Args: + checkpoint_path_or_dict (`str` or `dict`): Path to `.ckpt` file, or the state dict. + original_config_file (`str`): + Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically + inferred by looking for a key that only exists in SD2.0 models. + image_size (`int`, *optional*, defaults to 512): + The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2 + Base. Use 768 for Stable Diffusion v2. + prediction_type (`str`, *optional*): + The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable + Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2. + num_in_channels (`int`, *optional*, defaults to None): + The number of input channels. If `None`, it will be automatically inferred. + scheduler_type (`str`, *optional*, defaults to 'pndm'): + Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", + "ddim"]`. + model_type (`str`, *optional*, defaults to `None`): + The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder", + "FrozenCLIPEmbedder", "PaintByExample"]`. + is_img2img (`bool`, *optional*, defaults to `False`): + Whether the model should be loaded as an img2img pipeline. + extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for + checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to + `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for + inference. Non-EMA weights are usually better to continue fine-tuning. + upcast_attention (`bool`, *optional*, defaults to `None`): + Whether the attention computation should always be upcasted. This is necessary when running stable + diffusion 2.1. + device (`str`, *optional*, defaults to `None`): + The device to use. Pass `None` to determine automatically. + from_safetensors (`str`, *optional*, defaults to `False`): + If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. + load_safety_checker (`bool`, *optional*, defaults to `True`): + Whether to load the safety checker or not. Defaults to `True`. + safety_checker (`StableDiffusionSafetyChecker`, *optional*, defaults to `None`): + Safety checker to use. If this parameter is `None`, the function will load a new instance of + [StableDiffusionSafetyChecker] by itself, if needed. + feature_extractor (`AutoFeatureExtractor`, *optional*, defaults to `None`): + Feature extractor to use. If this parameter is `None`, the function will load a new instance of + [AutoFeatureExtractor] by itself, if needed. + pipeline_class (`str`, *optional*, defaults to `None`): + The pipeline class to use. Pass `None` to determine automatically. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + vae (`AutoencoderKL`, *optional*, defaults to `None`): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. If + this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed. + text_encoder (`CLIPTextModel`, *optional*, defaults to `None`): + An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) + to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) + variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed. + tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`): + An instance of + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) + to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if + needed. + config_files (`dict[str, str]`, *optional*, defaults to `None`): + A dictionary mapping from config file names to their contents. If this parameter is `None`, the function + will load the config files by itself, if needed. Valid keys are: + - `v1`: Config file for Stable Diffusion v1 + - `v2`: Config file for Stable Diffusion v2 + - `xl`: Config file for Stable Diffusion XL + - `xl_refiner`: Config file for Stable Diffusion XL Refiner + return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file. + """ + + # import pipelines here to avoid circular import error when using from_single_file method + from diffusers import ( + LDMTextToImagePipeline, + PaintByExamplePipeline, + StableDiffusionControlNetPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionPipeline, + StableDiffusionUpscalePipeline, + StableDiffusionXLControlNetInpaintPipeline, + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLPipeline, + StableUnCLIPImg2ImgPipeline, + StableUnCLIPPipeline, + ) + + if prediction_type == "v-prediction": + prediction_type = "v_prediction" + + if isinstance(checkpoint_path_or_dict, str): + if from_safetensors: + from safetensors.torch import load_file as safe_load + + checkpoint = safe_load(checkpoint_path_or_dict, device="cpu") + else: + if device is None: + device = get_device() + checkpoint = torch.load(checkpoint_path_or_dict, map_location=device) + else: + checkpoint = torch.load(checkpoint_path_or_dict, map_location=device) + elif isinstance(checkpoint_path_or_dict, dict): + checkpoint = checkpoint_path_or_dict + + # Sometimes models don't have the global_step item + if "global_step" in checkpoint: + global_step = checkpoint["global_step"] + else: + logger.debug("global_step key not found in model") + global_step = None + + # NOTE: this while loop isn't great but this controlnet checkpoint has one additional + # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + if original_config_file is None: + key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" + key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias" + key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias" + is_upscale = pipeline_class == StableDiffusionUpscalePipeline + + config_url = None + + # model_type = "v1" + if config_files is not None and "v1" in config_files: + original_config_file = config_files["v1"] + else: + config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" + + if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024: + # model_type = "v2" + if config_files is not None and "v2" in config_files: + original_config_file = config_files["v2"] + else: + config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" + if global_step == 110000: + # v2.1 needs to upcast attention + upcast_attention = True + elif key_name_sd_xl_base in checkpoint: + # only base xl has two text embedders + if config_files is not None and "xl" in config_files: + original_config_file = config_files["xl"] + else: + config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml" + elif key_name_sd_xl_refiner in checkpoint: + # only refiner xl has embedder and one text embedders + if config_files is not None and "xl_refiner" in config_files: + original_config_file = config_files["xl_refiner"] + else: + config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml" + + if is_upscale: + config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml" + + if config_url is not None: + original_config_file = BytesIO(requests.get(config_url, timeout=DIFFUSERS_REQUEST_TIMEOUT).content) + else: + with open(original_config_file, "r") as f: + original_config_file = f.read() + else: + with open(original_config_file, "r") as f: + original_config_file = f.read() + + original_config = yaml.safe_load(original_config_file) + + # Convert the text model. + if ( + model_type is None + and "cond_stage_config" in original_config["model"]["params"] + and original_config["model"]["params"]["cond_stage_config"] is not None + ): + model_type = original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1] + logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}") + elif model_type is None and original_config["model"]["params"]["network_config"] is not None: + if original_config["model"]["params"]["network_config"]["params"]["context_dim"] == 2048: + model_type = "SDXL" + else: + model_type = "SDXL-Refiner" + if image_size is None: + image_size = 1024 + + if pipeline_class is None: + # Check if we have a SDXL or SD model and initialize default pipeline + if model_type not in ["SDXL", "SDXL-Refiner"]: + pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline + else: + pipeline_class = StableDiffusionXLPipeline if model_type == "SDXL" else StableDiffusionXLImg2ImgPipeline + + if num_in_channels is None and pipeline_class in [ + StableDiffusionInpaintPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLControlNetInpaintPipeline, + ]: + num_in_channels = 9 + if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline: + num_in_channels = 7 + elif num_in_channels is None: + num_in_channels = 4 + + if "unet_config" in original_config["model"]["params"]: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + elif "network_config" in original_config["model"]["params"]: + original_config["model"]["params"]["network_config"]["params"]["in_channels"] = num_in_channels + + if ( + "parameterization" in original_config["model"]["params"] + and original_config["model"]["params"]["parameterization"] == "v" + ): + if prediction_type is None: + # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` + # as it relies on a brittle global step parameter here + prediction_type = "epsilon" if global_step == 875000 else "v_prediction" + if image_size is None: + # NOTE: For stable diffusion 2 base one has to pass `image_size==512` + # as it relies on a brittle global step parameter here + image_size = 512 if global_step == 875000 else 768 + else: + if prediction_type is None: + prediction_type = "epsilon" + if image_size is None: + image_size = 512 + + if controlnet is None and "control_stage_config" in original_config["model"]["params"]: + path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else "" + controlnet = convert_controlnet_checkpoint( + checkpoint, original_config, path, image_size, upcast_attention, extract_ema + ) + + if "timesteps" in original_config["model"]["params"]: + num_train_timesteps = original_config["model"]["params"]["timesteps"] + else: + num_train_timesteps = 1000 + + if model_type in ["SDXL", "SDXL-Refiner"]: + scheduler_dict = { + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "beta_end": 0.012, + "interpolation_type": "linear", + "num_train_timesteps": num_train_timesteps, + "prediction_type": "epsilon", + "sample_max_value": 1.0, + "set_alpha_to_one": False, + "skip_prk_steps": True, + "steps_offset": 1, + "timestep_spacing": "leading", + } + scheduler = EulerDiscreteScheduler.from_config(scheduler_dict) + scheduler_type = "euler" + else: + if "linear_start" in original_config["model"]["params"]: + beta_start = original_config["model"]["params"]["linear_start"] + else: + beta_start = 0.02 + + if "linear_end" in original_config["model"]["params"]: + beta_end = original_config["model"]["params"]["linear_end"] + else: + beta_end = 0.085 + scheduler = DDIMScheduler( + beta_end=beta_end, + beta_schedule="scaled_linear", + beta_start=beta_start, + num_train_timesteps=num_train_timesteps, + steps_offset=1, + clip_sample=False, + set_alpha_to_one=False, + prediction_type=prediction_type, + ) + # make sure scheduler works correctly with DDIM + scheduler.register_to_config(clip_sample=False) + + if scheduler_type == "pndm": + config = dict(scheduler.config) + config["skip_prk_steps"] = True + scheduler = PNDMScheduler.from_config(config) + elif scheduler_type == "lms": + scheduler = LMSDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "heun": + scheduler = HeunDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler": + scheduler = EulerDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "euler-ancestral": + scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) + elif scheduler_type == "dpm": + scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) + elif scheduler_type == "ddim": + scheduler = scheduler + else: + raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") + + if pipeline_class == StableDiffusionUpscalePipeline: + image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"] + + # Convert the UNet2DConditionModel model. + unet_config = create_unet_diffusers_config(original_config, image_size=image_size) + unet_config["upcast_attention"] = upcast_attention + + path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else "" + converted_unet_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, unet_config, path=path, extract_ema=extract_ema + ) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + unet = UNet2DConditionModel(**unet_config) + + if is_accelerate_available(): + if model_type not in ["SDXL", "SDXL-Refiner"]: # SBM Delay this. + for param_name, param in converted_unet_checkpoint.items(): + set_module_tensor_to_device(unet, param_name, "cpu", value=param) + else: + unet.load_state_dict(converted_unet_checkpoint) + + # Convert the VAE model. + if vae_path is None and vae is None: + vae_config = create_vae_diffusers_config(original_config, image_size=image_size) + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + + if ( + "model" in original_config + and "params" in original_config["model"] + and "scale_factor" in original_config["model"]["params"] + ): + vae_scaling_factor = original_config["model"]["params"]["scale_factor"] + else: + vae_scaling_factor = 0.18215 # default SD scaling factor + + vae_config["scaling_factor"] = vae_scaling_factor + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + vae = AutoencoderKL(**vae_config) + + if is_accelerate_available(): + for param_name, param in converted_vae_checkpoint.items(): + set_module_tensor_to_device(vae, param_name, "cpu", value=param) + else: + vae.load_state_dict(converted_vae_checkpoint) + elif vae is None: + vae = AutoencoderKL.from_pretrained(vae_path, local_files_only=local_files_only) + + if model_type == "FrozenOpenCLIPEmbedder": + config_name = "stabilityai/stable-diffusion-2" + config_kwargs = {"subfolder": "text_encoder"} + + if text_encoder is None: + text_model = convert_open_clip_checkpoint( + checkpoint, config_name, local_files_only=local_files_only, **config_kwargs + ) + else: + text_model = text_encoder + + try: + tokenizer = CLIPTokenizer.from_pretrained( + "stabilityai/stable-diffusion-2", subfolder="tokenizer", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'stabilityai/stable-diffusion-2'." + ) + + if stable_unclip is None: + if controlnet: + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + controlnet=controlnet, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + if hasattr(pipe, "requires_safety_checker"): + pipe.requires_safety_checker = False + + elif pipeline_class == StableDiffusionUpscalePipeline: + scheduler = DDIMScheduler.from_pretrained( + "stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler" + ) + low_res_scheduler = DDPMScheduler.from_pretrained( + "stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler" + ) + + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + low_res_scheduler=low_res_scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + else: + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + if hasattr(pipe, "requires_safety_checker"): + pipe.requires_safety_checker = False + + else: + image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components( + original_config, clip_stats_path=clip_stats_path, device=device + ) + + if stable_unclip == "img2img": + feature_extractor, image_encoder = stable_unclip_image_encoder(original_config) + + pipe = StableUnCLIPImg2ImgPipeline( + # image encoding components + feature_extractor=feature_extractor, + image_encoder=image_encoder, + # image noising components + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + # regular denoising components + tokenizer=tokenizer, + text_encoder=text_model, + unet=unet, + scheduler=scheduler, + # vae + vae=vae, + ) + elif stable_unclip == "txt2img": + if stable_unclip_prior is None or stable_unclip_prior == "karlo": + karlo_model = "kakaobrain/karlo-v1-alpha" + prior = PriorTransformer.from_pretrained( + karlo_model, subfolder="prior", local_files_only=local_files_only + ) + + try: + prior_tokenizer = CLIPTokenizer.from_pretrained( + "openai/clip-vit-large-patch14", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." + ) + prior_text_model = CLIPTextModelWithProjection.from_pretrained( + "openai/clip-vit-large-patch14", local_files_only=local_files_only + ) + + prior_scheduler = UnCLIPScheduler.from_pretrained( + karlo_model, subfolder="prior_scheduler", local_files_only=local_files_only + ) + prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config) + else: + raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}") + + pipe = StableUnCLIPPipeline( + # prior components + prior_tokenizer=prior_tokenizer, + prior_text_encoder=prior_text_model, + prior=prior, + prior_scheduler=prior_scheduler, + # image noising components + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + # regular denoising components + tokenizer=tokenizer, + text_encoder=text_model, + unet=unet, + scheduler=scheduler, + # vae + vae=vae, + ) + else: + raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}") + elif model_type == "PaintByExample": + vision_model = convert_paint_by_example_checkpoint(checkpoint) + try: + tokenizer = CLIPTokenizer.from_pretrained( + "openai/clip-vit-large-patch14", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." + ) + try: + feature_extractor = AutoFeatureExtractor.from_pretrained( + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the feature_extractor in the following path: 'CompVis/stable-diffusion-safety-checker'." + ) + pipe = PaintByExamplePipeline( + vae=vae, + image_encoder=vision_model, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=feature_extractor, + ) + elif model_type == "FrozenCLIPEmbedder": + text_model = convert_ldm_clip_checkpoint( + checkpoint, local_files_only=local_files_only, text_encoder=text_encoder + ) + try: + tokenizer = ( + CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only) + if tokenizer is None + else tokenizer + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." + ) + + if load_safety_checker: + safety_checker = StableDiffusionSafetyChecker.from_pretrained( + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only + ) + feature_extractor = AutoFeatureExtractor.from_pretrained( + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only + ) + + if controlnet: + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + else: + pipe = pipeline_class( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + elif model_type in ["SDXL", "SDXL-Refiner"]: + is_refiner = model_type == "SDXL-Refiner" + + if (is_refiner is False) and (tokenizer is None): + try: + tokenizer = CLIPTokenizer.from_pretrained( + "openai/clip-vit-large-patch14", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." + ) + + if (is_refiner is False) and (text_encoder is None): + text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) + + if tokenizer_2 is None: + try: + tokenizer_2 = CLIPTokenizer.from_pretrained( + "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'." + ) + + if text_encoder_2 is None: + config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + config_kwargs = {"projection_dim": 1280} + prefix = "conditioner.embedders.0.model." if is_refiner else "conditioner.embedders.1.model." + + text_encoder_2 = convert_open_clip_checkpoint( + checkpoint, + config_name, + prefix=prefix, + has_projection=True, + local_files_only=local_files_only, + **config_kwargs, + ) + + if is_accelerate_available(): # SBM Now move model to cpu. + for param_name, param in converted_unet_checkpoint.items(): + set_module_tensor_to_device(unet, param_name, "cpu", value=param) + + if controlnet: + pipe = pipeline_class( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + force_zeros_for_empty_prompt=True, + ) + elif adapter: + pipe = pipeline_class( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + unet=unet, + adapter=adapter, + scheduler=scheduler, + force_zeros_for_empty_prompt=True, + ) + + else: + pipeline_kwargs = { + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "text_encoder_2": text_encoder_2, + "tokenizer_2": tokenizer_2, + "unet": unet, + "scheduler": scheduler, + } + + if (pipeline_class == StableDiffusionXLImg2ImgPipeline) or ( + pipeline_class == StableDiffusionXLInpaintPipeline + ): + pipeline_kwargs.update({"requires_aesthetics_score": is_refiner}) + + if is_refiner: + pipeline_kwargs.update({"force_zeros_for_empty_prompt": False}) + + pipe = pipeline_class(**pipeline_kwargs) + else: + text_config = create_ldm_bert_config(original_config) + text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) + tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", local_files_only=local_files_only) + pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + + return pipe + + +def download_controlnet_from_original_ckpt( + checkpoint_path: str, + original_config_file: str, + image_size: int = 512, + extract_ema: bool = False, + num_in_channels: int | None = None, + upcast_attention: bool | None = None, + device: str = None, + from_safetensors: bool = False, + use_linear_projection: bool | None = None, + cross_attention_dim: bool | None = None, +) -> DiffusionPipeline: + if from_safetensors: + from safetensors import safe_open + + checkpoint = {} + with safe_open(checkpoint_path, framework="pt", device="cpu") as f: + for key in f.keys(): + checkpoint[key] = f.get_tensor(key) + else: + if device is None: + device = get_device() + checkpoint = torch.load(checkpoint_path, map_location=device) + else: + checkpoint = torch.load(checkpoint_path, map_location=device) + + # NOTE: this while loop isn't great but this controlnet checkpoint has one additional + # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + with open(original_config_file, "r") as f: + original_config_file = f.read() + original_config = yaml.safe_load(original_config_file) + + if num_in_channels is not None: + original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + + if "control_stage_config" not in original_config["model"]["params"]: + raise ValueError("`control_stage_config` not present in original config") + + controlnet = convert_controlnet_checkpoint( + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=use_linear_projection, + cross_attention_dim=cross_attention_dim, + ) + + return controlnet diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..c15163917a7cd992e47abaee6d5ddefb5f89c435 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -0,0 +1,470 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from functools import partial + +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict +from flax.jax_utils import unreplicate +from flax.training.common_utils import shard +from packaging import version +from PIL import Image +from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel + +from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel +from ...schedulers import ( + FlaxDDIMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, +) +from ...utils import deprecate, logging, replace_example_docstring +from ..pipeline_flax_utils import FlaxDiffusionPipeline +from .pipeline_output import FlaxStableDiffusionPipelineOutput +from .safety_checker_flax import FlaxStableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Set to True to use python for loop instead of jax.fori_loop for easier debugging +DEBUG = False + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import jax + >>> import numpy as np + >>> from flax.jax_utils import replicate + >>> from flax.training.common_utils import shard + + >>> from diffusers import FlaxStableDiffusionPipeline + + >>> pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-v1-5", variant="bf16", dtype=jax.numpy.bfloat16 + ... ) + + >>> prompt = "a photo of an astronaut riding a horse on mars" + + >>> prng_seed = jax.random.PRNGKey(0) + >>> num_inference_steps = 50 + + >>> num_samples = jax.device_count() + >>> prompt = num_samples * [prompt] + >>> prompt_ids = pipeline.prepare_inputs(prompt) + # shard inputs and rng + + >>> params = replicate(params) + >>> prng_seed = jax.random.split(prng_seed, jax.device_count()) + >>> prompt_ids = shard(prompt_ids) + + >>> images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images + >>> images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) + ``` +""" + + +class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): + r""" + Flax-based pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`FlaxAutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.FlaxCLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`FlaxUNet2DConditionModel`]): + A `FlaxUNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or + [`FlaxDPMSolverMultistepScheduler`]. + safety_checker ([`FlaxStableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + def __init__( + self, + vae: FlaxAutoencoderKL, + text_encoder: FlaxCLIPTextModel, + tokenizer: CLIPTokenizer, + unet: FlaxUNet2DConditionModel, + scheduler: FlaxDDIMScheduler | FlaxPNDMScheduler | FlaxLMSDiscreteScheduler | FlaxDPMSolverMultistepScheduler, + safety_checker: FlaxStableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + dtype: jnp.dtype = jnp.float32, + ): + super().__init__() + self.dtype = dtype + + if safety_checker is None: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + + def prepare_inputs(self, prompt: str | list[str]): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + return text_input.input_ids + + def _get_has_nsfw_concepts(self, features, params): + has_nsfw_concepts = self.safety_checker(features, params) + return has_nsfw_concepts + + def _run_safety_checker(self, images, safety_model_params, jit=False): + # safety_model_params should already be replicated when jit is True + pil_images = [Image.fromarray(image) for image in images] + features = self.feature_extractor(pil_images, return_tensors="np").pixel_values + + if jit: + features = shard(features) + has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) + has_nsfw_concepts = unshard(has_nsfw_concepts) + safety_model_params = unreplicate(safety_model_params) + else: + has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) + + images_was_copied = False + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if not images_was_copied: + images_was_copied = True + images = images.copy() + + images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image + + if any(has_nsfw_concepts): + warnings.warn( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead. Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts + + def _generate( + self, + prompt_ids: jnp.array, + params: dict | FrozenDict, + prng_seed: jax.Array, + num_inference_steps: int, + height: int, + width: int, + guidance_scale: float, + latents: jnp.ndarray | None = None, + neg_prompt_ids: jnp.ndarray | None = None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get prompt text embeddings + prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] + + # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` + # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` + batch_size = prompt_ids.shape[0] + + max_length = prompt_ids.shape[-1] + + if neg_prompt_ids is None: + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ).input_ids + else: + uncond_input = neg_prompt_ids + negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] + context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) + + # Ensure model output will be `float32` before going into the scheduler + guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32) + + latents_shape = ( + batch_size, + self.unet.config.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if latents is None: + latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + def loop_body(step, args): + latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": params["unet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + ).sample + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents, scheduler_state + + scheduler_state = self.scheduler.set_timesteps( + params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape + ) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * params["scheduler"].init_noise_sigma + + if DEBUG: + # run with python for loop + for i in range(num_inference_steps): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + else: + latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + + # scale and decode the image latents with vae + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample + + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image + + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt_ids: jnp.array, + params: dict | FrozenDict, + prng_seed: jax.Array, + num_inference_steps: int = 50, + height: int | None = None, + width: int | None = None, + guidance_scale: float | jnp.ndarray = 7.5, + latents: jnp.ndarray = None, + neg_prompt_ids: jnp.ndarray = None, + return_dict: bool = True, + jit: bool = False, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + latents (`jnp.ndarray`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + array is generated by sampling using the supplied random `generator`. + jit (`bool`, defaults to `False`): + Whether to run `pmap` versions of the generation and safety scoring functions. + + > [!WARNING] > This argument exists because `__call__` is not yet end-to-end pmap-able. It will be + removed in a > future release. + + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of + a plain tuple. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated images + and the second element is a list of `bool`s indicating whether the corresponding generated image + contains "not-safe-for-work" (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + if isinstance(guidance_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. + guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + guidance_scale = guidance_scale[:, None] + + if jit: + images = _p_generate( + self, + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + ) + else: + images = self._generate( + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + ) + + if self.safety_checker is not None: + safety_params = params["safety_checker"] + images_uint8_casted = (images * 255).round().astype("uint8") + num_devices, batch_size = images.shape[:2] + + images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) + images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) + images = np.asarray(images).copy() + + # block images + if any(has_nsfw_concept): + for i, is_nsfw in enumerate(has_nsfw_concept): + if is_nsfw: + images[i, 0] = np.asarray(images_uint8_casted[i]) + + images = images.reshape(num_devices, batch_size, height, width, 3) + else: + images = np.asarray(images) + has_nsfw_concept = False + + if not return_dict: + return (images, has_nsfw_concept) + + return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) + + +# Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation. +# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). +@partial( + jax.pmap, + in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0), + static_broadcasted_argnums=(0, 4, 5, 6), +) +def _p_generate( + pipe, + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, +): + return pipe._generate( + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + ) + + +@partial(jax.pmap, static_broadcasted_argnums=(0,)) +def _p_get_has_nsfw_concepts(pipe, features, params): + return pipe._get_has_nsfw_concepts(features, params) + + +def unshard(x: jnp.ndarray): + # einops.rearrange(x, 'd b ... -> (d b) ...') + num_devices, batch_size = x.shape[:2] + rest = x.shape[2:] + return x.reshape(num_devices * batch_size, *rest) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..e1c004ceb44c92d0782756530e854dba8734f1ee --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -0,0 +1,525 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from functools import partial + +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict +from flax.jax_utils import unreplicate +from flax.training.common_utils import shard +from PIL import Image +from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel + +from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel +from ...schedulers import ( + FlaxDDIMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, +) +from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring +from ..pipeline_flax_utils import FlaxDiffusionPipeline +from .pipeline_output import FlaxStableDiffusionPipelineOutput +from .safety_checker_flax import FlaxStableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Set to True to use python for loop instead of jax.fori_loop for easier debugging +DEBUG = False + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import jax + >>> import numpy as np + >>> import jax.numpy as jnp + >>> from flax.jax_utils import replicate + >>> from flax.training.common_utils import shard + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + >>> from diffusers import FlaxStableDiffusionImg2ImgPipeline + + + >>> def create_key(seed=0): + ... return jax.random.PRNGKey(seed) + + + >>> rng = create_key(0) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> response = requests.get(url) + >>> init_img = Image.open(BytesIO(response.content)).convert("RGB") + >>> init_img = init_img.resize((768, 512)) + + >>> prompts = "A fantasy landscape, trending on artstation" + + >>> pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained( + ... "CompVis/stable-diffusion-v1-4", + ... revision="flax", + ... dtype=jnp.bfloat16, + ... ) + + >>> num_samples = jax.device_count() + >>> rng = jax.random.split(rng, jax.device_count()) + >>> prompt_ids, processed_image = pipeline.prepare_inputs( + ... prompt=[prompts] * num_samples, image=[init_img] * num_samples + ... ) + >>> p_params = replicate(params) + >>> prompt_ids = shard(prompt_ids) + >>> processed_image = shard(processed_image) + + >>> output = pipeline( + ... prompt_ids=prompt_ids, + ... image=processed_image, + ... params=p_params, + ... prng_seed=rng, + ... strength=0.75, + ... num_inference_steps=50, + ... jit=True, + ... height=512, + ... width=768, + ... ).images + + >>> output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:]))) + ``` +""" + + +class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): + r""" + Flax-based pipeline for text-guided image-to-image generation using Stable Diffusion. + + This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`FlaxAutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.FlaxCLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`FlaxUNet2DConditionModel`]): + A `FlaxUNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or + [`FlaxDPMSolverMultistepScheduler`]. + safety_checker ([`FlaxStableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + def __init__( + self, + vae: FlaxAutoencoderKL, + text_encoder: FlaxCLIPTextModel, + tokenizer: CLIPTokenizer, + unet: FlaxUNet2DConditionModel, + scheduler: FlaxDDIMScheduler | FlaxPNDMScheduler | FlaxLMSDiscreteScheduler | FlaxDPMSolverMultistepScheduler, + safety_checker: FlaxStableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + dtype: jnp.dtype = jnp.float32, + ): + super().__init__() + self.dtype = dtype + + if safety_checker is None: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + + def prepare_inputs(self, prompt: str | list[str], image: Image.Image | list[Image.Image]): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if not isinstance(image, (Image.Image, list)): + raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") + + if isinstance(image, Image.Image): + image = [image] + + processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image]) + + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + return text_input.input_ids, processed_images + + def _get_has_nsfw_concepts(self, features, params): + has_nsfw_concepts = self.safety_checker(features, params) + return has_nsfw_concepts + + def _run_safety_checker(self, images, safety_model_params, jit=False): + # safety_model_params should already be replicated when jit is True + pil_images = [Image.fromarray(image) for image in images] + features = self.feature_extractor(pil_images, return_tensors="np").pixel_values + + if jit: + features = shard(features) + has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) + has_nsfw_concepts = unshard(has_nsfw_concepts) + safety_model_params = unreplicate(safety_model_params) + else: + has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) + + images_was_copied = False + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if not images_was_copied: + images_was_copied = True + images = images.copy() + + images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image + + if any(has_nsfw_concepts): + warnings.warn( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead. Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts + + def get_timestep_start(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + + return t_start + + def _generate( + self, + prompt_ids: jnp.ndarray, + image: jnp.ndarray, + params: dict | FrozenDict, + prng_seed: jax.Array, + start_timestep: int, + num_inference_steps: int, + height: int, + width: int, + guidance_scale: float, + noise: jnp.ndarray | None = None, + neg_prompt_ids: jnp.ndarray | None = None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get prompt text embeddings + prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] + + # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` + # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` + batch_size = prompt_ids.shape[0] + + max_length = prompt_ids.shape[-1] + + if neg_prompt_ids is None: + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ).input_ids + else: + uncond_input = neg_prompt_ids + negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] + context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) + + latents_shape = ( + batch_size, + self.unet.config.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if noise is None: + noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) + else: + if noise.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {noise.shape}, expected {latents_shape}") + + # Create init_latents + init_latent_dist = self.vae.apply({"params": params["vae"]}, image, method=self.vae.encode).latent_dist + init_latents = init_latent_dist.sample(key=prng_seed).transpose((0, 3, 1, 2)) + init_latents = self.vae.config.scaling_factor * init_latents + + def loop_body(step, args): + latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": params["unet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + ).sample + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents, scheduler_state + + scheduler_state = self.scheduler.set_timesteps( + params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape + ) + + latent_timestep = scheduler_state.timesteps[start_timestep : start_timestep + 1].repeat(batch_size) + + latents = self.scheduler.add_noise(params["scheduler"], init_latents, noise, latent_timestep) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * params["scheduler"].init_noise_sigma + + if DEBUG: + # run with python for loop + for i in range(start_timestep, num_inference_steps): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + else: + latents, _ = jax.lax.fori_loop(start_timestep, num_inference_steps, loop_body, (latents, scheduler_state)) + + # scale and decode the image latents with vae + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample + + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image + + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt_ids: jnp.ndarray, + image: jnp.ndarray, + params: dict | FrozenDict, + prng_seed: jax.Array, + strength: float = 0.8, + num_inference_steps: int = 50, + height: int | None = None, + width: int | None = None, + guidance_scale: float | jnp.ndarray = 7.5, + noise: jnp.ndarray = None, + neg_prompt_ids: jnp.ndarray = None, + return_dict: bool = True, + jit: bool = False, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt_ids (`jnp.ndarray`): + The prompt or prompts to guide image generation. + image (`jnp.ndarray`): + Array representing an image batch to be used as the starting point. + params (`Dict` or `FrozenDict`): + Dictionary containing the model parameters/weights. + prng_seed (`jax.Array` or `jax.Array`): + Array containing random number generator key. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + noise (`jnp.ndarray`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. The array is generated by + sampling using the supplied random `generator`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of + a plain tuple. + jit (`bool`, defaults to `False`): + Whether to run `pmap` versions of the generation and safety scoring functions. + + > [!WARNING] > This argument exists because `__call__` is not yet end-to-end pmap-able. It will be + removed in a > future release. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated images + and the second element is a list of `bool`s indicating whether the corresponding generated image + contains "not-safe-for-work" (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + if isinstance(guidance_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. + guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + guidance_scale = guidance_scale[:, None] + + start_timestep = self.get_timestep_start(num_inference_steps, strength) + + if jit: + images = _p_generate( + self, + prompt_ids, + image, + params, + prng_seed, + start_timestep, + num_inference_steps, + height, + width, + guidance_scale, + noise, + neg_prompt_ids, + ) + else: + images = self._generate( + prompt_ids, + image, + params, + prng_seed, + start_timestep, + num_inference_steps, + height, + width, + guidance_scale, + noise, + neg_prompt_ids, + ) + + if self.safety_checker is not None: + safety_params = params["safety_checker"] + images_uint8_casted = (images * 255).round().astype("uint8") + num_devices, batch_size = images.shape[:2] + + images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) + images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) + images = np.asarray(images) + + # block images + if any(has_nsfw_concept): + for i, is_nsfw in enumerate(has_nsfw_concept): + if is_nsfw: + images[i] = np.asarray(images_uint8_casted[i]) + + images = images.reshape(num_devices, batch_size, height, width, 3) + else: + images = np.asarray(images) + has_nsfw_concept = False + + if not return_dict: + return (images, has_nsfw_concept) + + return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) + + +# Static argnums are pipe, start_timestep, num_inference_steps, height, width. A change would trigger recompilation. +# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). +@partial( + jax.pmap, + in_axes=(None, 0, 0, 0, 0, None, None, None, None, 0, 0, 0), + static_broadcasted_argnums=(0, 5, 6, 7, 8), +) +def _p_generate( + pipe, + prompt_ids, + image, + params, + prng_seed, + start_timestep, + num_inference_steps, + height, + width, + guidance_scale, + noise, + neg_prompt_ids, +): + return pipe._generate( + prompt_ids, + image, + params, + prng_seed, + start_timestep, + num_inference_steps, + height, + width, + guidance_scale, + noise, + neg_prompt_ids, + ) + + +@partial(jax.pmap, static_broadcasted_argnums=(0,)) +def _p_get_has_nsfw_concepts(pipe, features, params): + return pipe._get_has_nsfw_concepts(features, params) + + +def unshard(x: jnp.ndarray): + # einops.rearrange(x, 'd b ... -> (d b) ...') + num_devices, batch_size = x.shape[:2] + rest = x.shape[2:] + return x.reshape(num_devices * batch_size, *rest) + + +def preprocess(image, dtype): + w, h = image.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = jnp.array(image).astype(dtype) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + return 2.0 * image - 1.0 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..dbda93deebff4e253264dbff5510477d3bcbedce --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py @@ -0,0 +1,582 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from functools import partial + +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict +from flax.jax_utils import unreplicate +from flax.training.common_utils import shard +from packaging import version +from PIL import Image +from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel + +from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel +from ...schedulers import ( + FlaxDDIMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, +) +from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring +from ..pipeline_flax_utils import FlaxDiffusionPipeline +from .pipeline_output import FlaxStableDiffusionPipelineOutput +from .safety_checker_flax import FlaxStableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Set to True to use python for loop instead of jax.fori_loop for easier debugging +DEBUG = False + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import jax + >>> import numpy as np + >>> from flax.jax_utils import replicate + >>> from flax.training.common_utils import shard + >>> import PIL + >>> import requests + >>> from io import BytesIO + >>> from diffusers import FlaxStableDiffusionInpaintPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + >>> init_image = download_image(img_url).resize((512, 512)) + >>> mask_image = download_image(mask_url).resize((512, 512)) + + >>> pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained( + ... "xvjiarui/stable-diffusion-2-inpainting" + ... ) + + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> prng_seed = jax.random.PRNGKey(0) + >>> num_inference_steps = 50 + + >>> num_samples = jax.device_count() + >>> prompt = num_samples * [prompt] + >>> init_image = num_samples * [init_image] + >>> mask_image = num_samples * [mask_image] + >>> prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs( + ... prompt, init_image, mask_image + ... ) + # shard inputs and rng + + >>> params = replicate(params) + >>> prng_seed = jax.random.split(prng_seed, jax.device_count()) + >>> prompt_ids = shard(prompt_ids) + >>> processed_masked_images = shard(processed_masked_images) + >>> processed_masks = shard(processed_masks) + + >>> images = pipeline( + ... prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True + ... ).images + >>> images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) + ``` +""" + + +class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): + r""" + Flax-based pipeline for text-guided image inpainting using Stable Diffusion. + + > [!WARNING] > 🧪 This is an experimental feature! + + This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`FlaxAutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.FlaxCLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`FlaxUNet2DConditionModel`]): + A `FlaxUNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or + [`FlaxDPMSolverMultistepScheduler`]. + safety_checker ([`FlaxStableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + def __init__( + self, + vae: FlaxAutoencoderKL, + text_encoder: FlaxCLIPTextModel, + tokenizer: CLIPTokenizer, + unet: FlaxUNet2DConditionModel, + scheduler: FlaxDDIMScheduler | FlaxPNDMScheduler | FlaxLMSDiscreteScheduler | FlaxDPMSolverMultistepScheduler, + safety_checker: FlaxStableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + dtype: jnp.dtype = jnp.float32, + ): + super().__init__() + self.dtype = dtype + + if safety_checker is None: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + + def prepare_inputs( + self, + prompt: str | list[str], + image: Image.Image | list[Image.Image], + mask: Image.Image | list[Image.Image], + ): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if not isinstance(image, (Image.Image, list)): + raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") + + if isinstance(image, Image.Image): + image = [image] + + if not isinstance(mask, (Image.Image, list)): + raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") + + if isinstance(mask, Image.Image): + mask = [mask] + + processed_images = jnp.concatenate([preprocess_image(img, jnp.float32) for img in image]) + processed_masks = jnp.concatenate([preprocess_mask(m, jnp.float32) for m in mask]) + # processed_masks[processed_masks < 0.5] = 0 + processed_masks = processed_masks.at[processed_masks < 0.5].set(0) + # processed_masks[processed_masks >= 0.5] = 1 + processed_masks = processed_masks.at[processed_masks >= 0.5].set(1) + + processed_masked_images = processed_images * (processed_masks < 0.5) + + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + return text_input.input_ids, processed_masked_images, processed_masks + + def _get_has_nsfw_concepts(self, features, params): + has_nsfw_concepts = self.safety_checker(features, params) + return has_nsfw_concepts + + def _run_safety_checker(self, images, safety_model_params, jit=False): + # safety_model_params should already be replicated when jit is True + pil_images = [Image.fromarray(image) for image in images] + features = self.feature_extractor(pil_images, return_tensors="np").pixel_values + + if jit: + features = shard(features) + has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) + has_nsfw_concepts = unshard(has_nsfw_concepts) + safety_model_params = unreplicate(safety_model_params) + else: + has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) + + images_was_copied = False + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if not images_was_copied: + images_was_copied = True + images = images.copy() + + images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image + + if any(has_nsfw_concepts): + warnings.warn( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead. Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts + + def _generate( + self, + prompt_ids: jnp.ndarray, + mask: jnp.ndarray, + masked_image: jnp.ndarray, + params: dict | FrozenDict, + prng_seed: jax.Array, + num_inference_steps: int, + height: int, + width: int, + guidance_scale: float, + latents: jnp.ndarray | None = None, + neg_prompt_ids: jnp.ndarray | None = None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get prompt text embeddings + prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] + + # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` + # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` + batch_size = prompt_ids.shape[0] + + max_length = prompt_ids.shape[-1] + + if neg_prompt_ids is None: + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ).input_ids + else: + uncond_input = neg_prompt_ids + negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] + context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) + + latents_shape = ( + batch_size, + self.vae.config.latent_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if latents is None: + latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + prng_seed, mask_prng_seed = jax.random.split(prng_seed) + + masked_image_latent_dist = self.vae.apply( + {"params": params["vae"]}, masked_image, method=self.vae.encode + ).latent_dist + masked_image_latents = masked_image_latent_dist.sample(key=mask_prng_seed).transpose((0, 3, 1, 2)) + masked_image_latents = self.vae.config.scaling_factor * masked_image_latents + del mask_prng_seed + + mask = jax.image.resize(mask, (*mask.shape[:-2], *masked_image_latents.shape[-2:]), method="nearest") + + # 8. Check that sizes of mask, masked image and latents match + num_channels_latents = self.vae.config.latent_channels + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + + def loop_body(step, args): + latents, mask, masked_image_latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + mask_input = jnp.concatenate([mask] * 2) + masked_image_latents_input = jnp.concatenate([masked_image_latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) + # concat latents, mask, masked_image_latents in the channel dimension + latents_input = jnp.concatenate([latents_input, mask_input, masked_image_latents_input], axis=1) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": params["unet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + ).sample + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents, mask, masked_image_latents, scheduler_state + + scheduler_state = self.scheduler.set_timesteps( + params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape + ) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * params["scheduler"].init_noise_sigma + + if DEBUG: + # run with python for loop + for i in range(num_inference_steps): + latents, mask, masked_image_latents, scheduler_state = loop_body( + i, (latents, mask, masked_image_latents, scheduler_state) + ) + else: + latents, _, _, _ = jax.lax.fori_loop( + 0, num_inference_steps, loop_body, (latents, mask, masked_image_latents, scheduler_state) + ) + + # scale and decode the image latents with vae + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample + + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image + + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt_ids: jnp.ndarray, + mask: jnp.ndarray, + masked_image: jnp.ndarray, + params: dict | FrozenDict, + prng_seed: jax.Array, + num_inference_steps: int = 50, + height: int | None = None, + width: int | None = None, + guidance_scale: float | jnp.ndarray = 7.5, + latents: jnp.ndarray = None, + neg_prompt_ids: jnp.ndarray = None, + return_dict: bool = True, + jit: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide image generation. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + latents (`jnp.ndarray`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + array is generated by sampling using the supplied random `generator`. + jit (`bool`, defaults to `False`): + Whether to run `pmap` versions of the generation and safety scoring functions. + + > [!WARNING] > This argument exists because `__call__` is not yet end-to-end pmap-able. It will be + removed in a > future release. + + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of + a plain tuple. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated images + and the second element is a list of `bool`s indicating whether the corresponding generated image + contains "not-safe-for-work" (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + masked_image = jax.image.resize(masked_image, (*masked_image.shape[:-2], height, width), method="bicubic") + mask = jax.image.resize(mask, (*mask.shape[:-2], height, width), method="nearest") + + if isinstance(guidance_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. + guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + guidance_scale = guidance_scale[:, None] + + if jit: + images = _p_generate( + self, + prompt_ids, + mask, + masked_image, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + ) + else: + images = self._generate( + prompt_ids, + mask, + masked_image, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + ) + + if self.safety_checker is not None: + safety_params = params["safety_checker"] + images_uint8_casted = (images * 255).round().astype("uint8") + num_devices, batch_size = images.shape[:2] + + images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) + images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) + images = np.asarray(images) + + # block images + if any(has_nsfw_concept): + for i, is_nsfw in enumerate(has_nsfw_concept): + if is_nsfw: + images[i] = np.asarray(images_uint8_casted[i]) + + images = images.reshape(num_devices, batch_size, height, width, 3) + else: + images = np.asarray(images) + has_nsfw_concept = False + + if not return_dict: + return (images, has_nsfw_concept) + + return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) + + +# Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation. +# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). +@partial( + jax.pmap, + in_axes=(None, 0, 0, 0, 0, 0, None, None, None, 0, 0, 0), + static_broadcasted_argnums=(0, 6, 7, 8), +) +def _p_generate( + pipe, + prompt_ids, + mask, + masked_image, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, +): + return pipe._generate( + prompt_ids, + mask, + masked_image, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + ) + + +@partial(jax.pmap, static_broadcasted_argnums=(0,)) +def _p_get_has_nsfw_concepts(pipe, features, params): + return pipe._get_has_nsfw_concepts(features, params) + + +def unshard(x: jnp.ndarray): + # einops.rearrange(x, 'd b ... -> (d b) ...') + num_devices, batch_size = x.shape[:2] + rest = x.shape[2:] + return x.reshape(num_devices * batch_size, *rest) + + +def preprocess_image(image, dtype): + w, h = image.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = jnp.array(image).astype(dtype) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask, dtype): + w, h = mask.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w, h)) + mask = jnp.array(mask.convert("L")).astype(dtype) / 255.0 + mask = jnp.expand_dims(mask, axis=(0, 1)) + + return mask diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..0f66ca909e7dc951c8faedde53ac353f90b5bdc2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py @@ -0,0 +1,488 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import torch +from transformers import CLIPImageProcessor, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import deprecate, logging +from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) + + +class OnnxStableDiffusionPipeline(DiffusionPipeline): + vae_encoder: OnnxRuntimeModel + vae_decoder: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + scheduler: DDIMScheduler | PNDMScheduler | LMSDiscreteScheduler + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPImageProcessor + + _optional_components = ["safety_checker", "feature_extractor"] + _is_onnx = True + + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: DDIMScheduler | PNDMScheduler | LMSDiscreteScheduler, + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def _encode_prompt( + self, + prompt: str | list[str], + num_images_per_prompt: int | None, + do_classifier_free_guidance: bool, + negative_prompt: str | None, + prompt_embeds: np.ndarray | None = None, + negative_prompt_embeds: np.ndarray | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: + negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def check_inputs( + self, + prompt: str | list[str], + height: int | None, + width: int | None, + callback_steps: int, + negative_prompt: str | None = None, + prompt_embeds: np.ndarray | None = None, + negative_prompt_embeds: np.ndarray | None = None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = 512, + width: int | None = 512, + num_inference_steps: int | None = 50, + guidance_scale: float | None = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float | None = 0.0, + generator: np.random.RandomState | None = None, + latents: np.ndarray | None = None, + prompt_embeds: np.ndarray | None = None, + negative_prompt_embeds: np.ndarray | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, np.ndarray], None] | None = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image` or list[`PIL.Image.Image`] or `torch.Tensor`): + `Image`, or tensor representing an image batch which will be upscaled. * + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + One or a list of [numpy generator(s)](TODO) to make generation deterministic. + latents (`np.ndarray`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if generator is None: + generator = np.random + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # get the initial random noise unless the user supplied it + latents_dtype = prompt_embeds.dtype + latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8) + if latents is None: + latents = generator.randn(*latents_shape).astype(latents_dtype) + elif latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + timestep_dtype = next( + (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = latent_model_input.cpu().numpy() + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds) + noise_pred = noise_pred[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + latents = 1 / 0.18215 * latents + # image = self.vae_decoder(latent_sample=latents)[0] + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="np" + ).pixel_values.astype(image.dtype) + + images, has_nsfw_concept = [], [] + for i in range(image.shape[0]): + image_i, has_nsfw_concept_i = self.safety_checker( + clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + ) + images.append(image_i) + has_nsfw_concept.append(has_nsfw_concept_i[0]) + image = np.concatenate(images) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + +class StableDiffusionOnnxPipeline(OnnxStableDiffusionPipeline): + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: DDIMScheduler | PNDMScheduler | LMSDiscreteScheduler, + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPImageProcessor, + ): + deprecation_message = "Please use `OnnxStableDiffusionPipeline` instead of `StableDiffusionOnnxPipeline`." + deprecate("StableDiffusionOnnxPipeline", "1.0.0", deprecation_message) + super().__init__( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..2bb3eb6990e11752b6b8417c8da928eae60a9ead --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py @@ -0,0 +1,550 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import PIL_INTERPOLATION, deprecate, logging +from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess with 8->64 +def preprocess(image): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): + r""" + Pipeline for text-guided image to image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + vae_encoder: OnnxRuntimeModel + vae_decoder: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + scheduler: DDIMScheduler | PNDMScheduler | LMSDiscreteScheduler + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPImageProcessor + + _optional_components = ["safety_checker", "feature_extractor"] + _is_onnx = True + + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: DDIMScheduler | PNDMScheduler | LMSDiscreteScheduler, + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt: str | list[str], + num_images_per_prompt: int | None, + do_classifier_free_guidance: bool, + negative_prompt: str | None, + prompt_embeds: np.ndarray | None = None, + negative_prompt_embeds: np.ndarray | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: + negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def check_inputs( + self, + prompt: str | list[str], + callback_steps: int, + negative_prompt: str | list[str] | None = None, + prompt_embeds: np.ndarray | None = None, + negative_prompt_embeds: np.ndarray | None = None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def __call__( + self, + prompt: str | list[str], + image: np.ndarray | PIL.Image.Image = None, + strength: float = 0.8, + num_inference_steps: int | None = 50, + guidance_scale: float | None = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float | None = 0.0, + generator: np.random.RandomState | None = None, + prompt_embeds: np.ndarray | None = None, + negative_prompt_embeds: np.ndarray | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, np.ndarray], None] | None = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + image (`np.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + A np.random.RandomState to make generation deterministic. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # check inputs. Raise error if not correct + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if generator is None: + generator = np.random + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + image = preprocess(image).cpu().numpy() + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + latents_dtype = prompt_embeds.dtype + image = image.astype(latents_dtype) + # encode the init image into latents and scale the latents + init_latents = self.vae_encoder(sample=image)[0] + init_latents = 0.18215 * init_latents + + if isinstance(prompt, str): + prompt = [prompt] + if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = len(prompt) // init_latents.shape[0] + init_latents = np.concatenate([init_latents] * additional_image_per_prompt * num_images_per_prompt, axis=0) + elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts." + ) + else: + init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0) + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps.numpy()[-init_timestep] + timesteps = np.array([timesteps] * batch_size * num_images_per_prompt) + + # add noise to latents using the timesteps + noise = generator.randn(*init_latents.shape).astype(latents_dtype) + init_latents = self.scheduler.add_noise( + torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) + ) + init_latents = init_latents.numpy() + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + latents = init_latents + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].numpy() + + timestep_dtype = next( + (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = latent_model_input.cpu().numpy() + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[ + 0 + ] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + latents = 1 / 0.18215 * latents + # image = self.vae_decoder(latent_sample=latents)[0] + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="np" + ).pixel_values.astype(image.dtype) + # safety_checker does not support batched inputs yet + images, has_nsfw_concept = [], [] + for i in range(image.shape[0]): + image_i, has_nsfw_concept_i = self.safety_checker( + clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + ) + images.append(image_i) + has_nsfw_concept.append(has_nsfw_concept_i[0]) + image = np.concatenate(images) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..7e3416a3383260f85834c9cd95549fe12321386c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py @@ -0,0 +1,564 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import PIL_INTERPOLATION, deprecate, logging +from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +NUM_UNET_INPUT_CHANNELS = 9 +NUM_LATENT_CHANNELS = 4 + + +def prepare_mask_and_masked_image(image, mask, latents_shape): + image = np.array(image.convert("RGB").resize((latents_shape[1] * 8, latents_shape[0] * 8))) + image = image[None].transpose(0, 3, 1, 2) + image = image.astype(np.float32) / 127.5 - 1.0 + + image_mask = np.array(mask.convert("L").resize((latents_shape[1] * 8, latents_shape[0] * 8))) + masked_image = image * (image_mask < 127.5) + + mask = mask.resize((latents_shape[1], latents_shape[0]), PIL_INTERPOLATION["nearest"]) + mask = np.array(mask.convert("L")) + mask = mask.astype(np.float32) / 255.0 + mask = mask[None, None] + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + return mask, masked_image + + +class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + vae_encoder: OnnxRuntimeModel + vae_decoder: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + scheduler: DDIMScheduler | PNDMScheduler | LMSDiscreteScheduler + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPImageProcessor + + _optional_components = ["safety_checker", "feature_extractor"] + _is_onnx = True + + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: DDIMScheduler | PNDMScheduler | LMSDiscreteScheduler, + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + logger.info("`OnnxStableDiffusionInpaintPipeline` is experimental and will very likely change in the future.") + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt: str | list[str], + num_images_per_prompt: int | None, + do_classifier_free_guidance: bool, + negative_prompt: str | None, + prompt_embeds: np.ndarray | None = None, + negative_prompt_embeds: np.ndarray | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: + negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt: str | list[str], + height: int | None, + width: int | None, + callback_steps: int, + negative_prompt: str | None = None, + prompt_embeds: np.ndarray | None = None, + negative_prompt_embeds: np.ndarray | None = None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str], + image: PIL.Image.Image, + mask_image: PIL.Image.Image, + height: int | None = 512, + width: int | None = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: np.random.RandomState | None = None, + latents: np.ndarray | None = None, + prompt_embeds: np.ndarray | None = None, + negative_prompt_embeds: np.ndarray | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, np.ndarray], None] | None = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + A np.random.RandomState to make generation deterministic. + latents (`np.ndarray`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if generator is None: + generator = np.random + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + num_channels_latents = NUM_LATENT_CHANNELS + latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8) + latents_dtype = prompt_embeds.dtype + if latents is None: + latents = generator.randn(*latents_shape).astype(latents_dtype) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + # prepare mask and masked_image + mask, masked_image = prepare_mask_and_masked_image(image, mask_image, latents_shape[-2:]) + mask = mask.astype(latents.dtype) + masked_image = masked_image.astype(latents.dtype) + + masked_image_latents = self.vae_encoder(sample=masked_image)[0] + masked_image_latents = 0.18215 * masked_image_latents + + # duplicate mask and masked_image_latents for each generation per prompt + mask = mask.repeat(batch_size * num_images_per_prompt, 0) + masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 0) + + mask = np.concatenate([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + np.concatenate([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + + unet_input_channels = NUM_UNET_INPUT_CHANNELS + if num_channels_latents + num_channels_mask + num_channels_masked_image != unet_input_channels: + raise ValueError( + "Incorrect configuration settings! The config of `pipeline.unet` expects" + f" {unet_input_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + timestep_dtype = next( + (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + # concat latents, mask, masked_image_latnets in the channel dimension + latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = latent_model_input.cpu().numpy() + latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1) + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[ + 0 + ] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + latents = 1 / 0.18215 * latents + # image = self.vae_decoder(latent_sample=latents)[0] + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="np" + ).pixel_values.astype(image.dtype) + # safety_checker does not support batched inputs yet + images, has_nsfw_concept = [], [] + for i in range(image.shape[0]): + image_i, has_nsfw_concept_i = self.safety_checker( + clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + ) + images.append(image_i) + has_nsfw_concept.append(has_nsfw_concept_i[0]) + image = np.concatenate(images) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py new file mode 100644 index 0000000000000000000000000000000000000000..c721a4f2ffa000dc3364b6bc4becdbf1f1a603e8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py @@ -0,0 +1,586 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers +from ...utils import deprecate, logging +from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) + + +def preprocess(image): + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 32 + + image = [np.array(i.resize((w, h)))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + + return image + + +class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline): + vae: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + low_res_scheduler: DDPMScheduler + scheduler: KarrasDiffusionSchedulers + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPImageProcessor + + _optional_components = ["safety_checker", "feature_extractor"] + _is_onnx = True + + def __init__( + self, + vae: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: Any, + unet: OnnxRuntimeModel, + low_res_scheduler: DDPMScheduler, + scheduler: KarrasDiffusionSchedulers, + safety_checker: OnnxRuntimeModel | None = None, + feature_extractor: CLIPImageProcessor | None = None, + max_noise_level: int = 350, + num_latent_channels=4, + num_unet_input_channels=7, + requires_safety_checker: bool = True, + ): + super().__init__() + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + low_res_scheduler=low_res_scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.register_to_config( + max_noise_level=max_noise_level, + num_latent_channels=num_latent_channels, + num_unet_input_channels=num_unet_input_channels, + ) + + def check_inputs( + self, + prompt: str | list[str], + image, + noise_level, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, np.ndarray) + and not isinstance(image, list) + ): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}" + ) + + # verify batch size of prompt and image are same if image is a list or tensor or numpy array + if isinstance(image, (list, np.ndarray)): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if isinstance(image, list): + image_batch_size = len(image) + else: + image_batch_size = image.shape[0] + if batch_size != image_batch_size: + raise ValueError( + f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}." + " Please make sure that passed `prompt` matches the batch size of `image`." + ) + + # check noise level + if noise_level > self.config.max_noise_level: + raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): + shape = (batch_size, num_channels_latents, height, width) + if latents is None: + latents = generator.randn(*shape).astype(dtype) + elif latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + + return latents + + def decode_latents(self, latents): + latents = 1 / 0.08333 * latents + image = self.vae(latent_sample=latents)[0] + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + return image + + def _encode_prompt( + self, + prompt: str | list[str], + num_images_per_prompt: int | None, + do_classifier_free_guidance: bool, + negative_prompt: str | None, + prompt_embeds: np.ndarray | None = None, + negative_prompt_embeds: np.ndarray | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: + negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def __call__( + self, + prompt: str | list[str], + image: np.ndarray | PIL.Image.Image | list[PIL.Image.Image], + num_inference_steps: int = 75, + guidance_scale: float = 9.0, + noise_level: int = 20, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: np.random.RandomState | list[np.random.RandomState] | None = None, + latents: np.ndarray | None = None, + prompt_embeds: np.ndarray | None = None, + negative_prompt_embeds: np.ndarray | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, np.ndarray], None] | None = None, + callback_steps: int | None = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + image (`np.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + noise_level (`float`, defaults to 0.2): + Deteremines the amount of noise to add to the initial image before performing upscaling. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + A np.random.RandomState to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # 1. Check inputs + self.check_inputs( + prompt, + image, + noise_level, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if generator is None: + generator = np.random + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + latents_dtype = prompt_embeds.dtype + image = preprocess(image).cpu().numpy() + height, width = image.shape[2:] + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + self.config.num_latent_channels, + height, + width, + latents_dtype, + generator, + ) + image = image.astype(latents_dtype) + + self.scheduler.set_timesteps(num_inference_steps) + timesteps = self.scheduler.timesteps + + # Scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # 5. Add noise to image + noise_level = np.array([noise_level]).astype(np.int64) + noise = generator.randn(*image.shape).astype(latents_dtype) + + image = self.low_res_scheduler.add_noise( + torch.from_numpy(image), torch.from_numpy(noise), torch.from_numpy(noise_level) + ) + image = image.numpy() + + batch_multiplier = 2 if do_classifier_free_guidance else 1 + image = np.concatenate([image] * batch_multiplier * num_images_per_prompt) + noise_level = np.concatenate([noise_level] * image.shape[0]) + + # 7. Check that sizes of image and latents match + num_channels_image = image.shape[1] + if self.config.num_latent_channels + num_channels_image != self.config.num_unet_input_channels: + raise ValueError( + "Incorrect configuration settings! The config of `pipeline.unet` expects" + f" {self.config.num_unet_input_channels} but received `num_channels_latents`: {self.config.num_latent_channels} +" + f" `num_channels_image`: {num_channels_image} " + f" = {self.config.num_latent_channels + num_channels_image}. Please verify the config of" + " `pipeline.unet` or your `image` input." + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + timestep_dtype = next( + (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = np.concatenate([latent_model_input, image], axis=1) + + # timestep to tensor + timestep = np.array([t], dtype=timestep_dtype) + + # predict the noise residual + noise_pred = self.unet( + sample=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + class_labels=noise_level, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ).prev_sample + latents = latents.numpy() + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # 10. Post-processing + image = self.decode_latents(latents) + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="np" + ).pixel_values.astype(image.dtype) + + images, has_nsfw_concept = [], [] + for i in range(image.shape[0]): + image_i, has_nsfw_concept_i = self.safety_checker( + clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + ) + images.append(image_i) + has_nsfw_concept.append(has_nsfw_concept_i[0]) + image = np.concatenate(images) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..804e5655d369a7db4123c8ce392e3b320dac792d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_output.py @@ -0,0 +1,44 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput, is_flax_available + + +@dataclass +class StableDiffusionPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + nsfw_content_detected (`list[bool]`) + list indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or + `None` if safety checking could not be performed. + """ + + images: list[PIL.Image.Image] | np.ndarray + nsfw_content_detected: list[bool] | None + + +if is_flax_available(): + import flax + + @flax.struct.dataclass + class FlaxStableDiffusionPipelineOutput(BaseOutput): + """ + Output class for Flax-based Stable Diffusion pipelines. + + Args: + images (`np.ndarray`): + Denoised images of array shape of `(batch_size, height, width, num_channels)`. + nsfw_content_detected (`list[bool]`): + list indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content + or `None` if safety checking could not be performed. + """ + + images: np.ndarray + nsfw_content_detected: list[bool] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..42d62f53a20a15c37ec7beb376db66aa7e7d18ce --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -0,0 +1,1107 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import Any, Callable + +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...configuration_utils import FrozenDict +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionPipeline + + >>> pipe = StableDiffusionPipeline.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + """ + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + self._is_unet_config_sample_size_int = unet is not None and isinstance(unet.config.sample_size, int) + is_unet_sample_size_less_64 = ( + unet is not None + and hasattr(unet.config, "sample_size") + and self._is_unet_config_sample_size_int + and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + guidance_rescale: float = 0.0, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width to unet + if not height or not width: + height = ( + self.unet.config.sample_size + if self._is_unet_config_sample_size_int + else self.unet.config.sample_size[0] + ) + width = ( + self.unet.config.sample_size + if self._is_unet_config_sample_size_int + else self.unet.config.sample_size[1] + ) + height, width = height * self.vae_scale_factor, width * self.vae_scale_factor + # to deal with lora scaling and other possible forward hooks + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare timesteps + timestep_device = device + if XLA_AVAILABLE: + timestep_device = "cpu" + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) + else None + ) + + # 6.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + if XLA_AVAILABLE: + xm.mark_step() + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py new file mode 100644 index 0000000000000000000000000000000000000000..3ec64c30763f7574c70167d32bb009abd571f5cf --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -0,0 +1,897 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from packaging import version +from transformers import CLIPTextModel, CLIPTokenizer, DPTForDepthEstimation, DPTImageProcessor + +from ...configuration_utils import FrozenDict +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin): + r""" + Pipeline for text-guided depth-based image-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "depth_mask"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + depth_estimator: DPTForDepthEstimation, + feature_extractor: DPTImageProcessor, + ): + super().__init__() + + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + depth_estimator=depth_estimator, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_guidance, dtype, device): + if isinstance(image, PIL.Image.Image): + image = [image] + else: + image = list(image) + + if isinstance(image[0], PIL.Image.Image): + width, height = image[0].size + elif isinstance(image[0], np.ndarray): + width, height = image[0].shape[:-1] + else: + height, width = image[0].shape[-2:] + + if depth_map is None: + pixel_values = self.feature_extractor(images=image, return_tensors="pt").pixel_values + pixel_values = pixel_values.to(device=device, dtype=dtype) + # The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16. + # So we use `torch.autocast` here for half precision inference. + if torch.backends.mps.is_available(): + autocast_ctx = contextlib.nullcontext() + logger.warning( + "The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16, but autocast is not yet supported on MPS." + ) + else: + autocast_ctx = torch.autocast(device.type, dtype=dtype) + + with autocast_ctx: + depth_map = self.depth_estimator(pixel_values).predicted_depth + else: + depth_map = depth_map.to(device=device, dtype=dtype) + + depth_map = torch.nn.functional.interpolate( + depth_map.unsqueeze(1), + size=(height // self.vae_scale_factor, width // self.vae_scale_factor), + mode="bicubic", + align_corners=False, + ) + + depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True) + depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True) + depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0 + depth_map = depth_map.to(dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if depth_map.shape[0] < batch_size: + repeat_by = batch_size // depth_map.shape[0] + depth_map = depth_map.repeat(repeat_by, 1, 1, 1) + + depth_map = torch.cat([depth_map] * 2) if do_classifier_free_guidance else depth_map + return depth_map + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + depth_map: torch.Tensor | None = None, + strength: float = 0.8, + num_inference_steps: int | None = 50, + guidance_scale: float | None = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float | None = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image` or tensor representing an image batch to be used as the starting point. Can accept image + latents as `image` only if `depth_map` is not `None`. + depth_map (`torch.Tensor`, *optional*): + Depth prediction to be used as additional conditioning for the image generation process. If not + defined, it automatically predicts the depth with `self.depth_estimator`. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + Examples: + + ```py + >>> import torch + >>> import requests + >>> from PIL import Image + + >>> from diffusers import StableDiffusionDepth2ImgPipeline + + >>> pipe = StableDiffusionDepth2ImgPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-depth", + ... torch_dtype=torch.float16, + ... ) + >>> pipe.to("cuda") + + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> init_image = Image.open(requests.get(url, stream=True).raw) + >>> prompt = "two tigers" + >>> n_prompt = "bad, deformed, ugly, bad anotomy" + >>> image = pipe(prompt=prompt, image=init_image, negative_prompt=n_prompt, strength=0.7).images[0] + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + # 1. Check inputs + self.check_inputs( + prompt, + strength, + callback_steps, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare depth mask + depth_mask = self.prepare_depth_map( + image, + depth_map, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + prompt_embeds.dtype, + device, + ) + + # 5. Preprocess image + image = self.image_processor.preprocess(image) + + # 6. Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 7. Prepare latent variables + latents = self.prepare_latents( + image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + depth_mask = callback_outputs.pop("depth_mask", depth_mask) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py new file mode 100644 index 0000000000000000000000000000000000000000..5974aea6b44894745d7d685f02268b240200d04f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -0,0 +1,439 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable + +import PIL.Image +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import deprecate, is_torch_xla_available, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableDiffusionImageVariationPipeline(DiffusionPipeline, StableDiffusionMixin): + r""" + Pipeline to generate image variations from an input image using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + # TODO: feature_extractor is required to encode images (if they are in PIL format), + # we should give a descriptive message if the pipeline doesn't have one. + _optional_components = ["safety_checker"] + model_cpu_offload_seq = "image_encoder->unet->vae" + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + image_encoder: CLIPVisionModelWithProjection, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(images=image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + negative_prompt_embeds = torch.zeros_like(image_embeddings) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) + + return image_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, image, height, width, callback_steps): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `list[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + image: PIL.Image.Image | list[PIL.Image.Image] | torch.Tensor, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `list[PIL.Image.Image]` or `torch.Tensor`): + Image or images to guide image generation. If you provide a tensor, it needs to be compatible with + [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + + Examples: + + ```py + from diffusers import StableDiffusionImageVariationPipeline + from PIL import Image + from io import BytesIO + import requests + + pipe = StableDiffusionImageVariationPipeline.from_pretrained( + "lambdalabs/sd-image-variations-diffusers", revision="v2.0" + ) + pipe = pipe.to("cuda") + + url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200" + + response = requests.get(url) + image = Image.open(BytesIO(response.content)).convert("RGB") + + out = pipe(image, num_images_per_prompt=3, guidance_scale=15) + out["images"][0].save("result.jpg") + ``` + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(image, height, width, callback_steps) + + # 2. Define call parameters + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input image + image_embeddings = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + self.maybe_free_model_hooks() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..abcd06a2bb3bb4d28e85da55391952833293c12b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -0,0 +1,1163 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...configuration_utils import FrozenDict +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import requests + >>> import torch + >>> from PIL import Image + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionImg2ImgPipeline + + >>> device = "cuda" + >>> model_id_or_path = "stable-diffusion-v1-5/stable-diffusion-v1-5" + >>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) + >>> pipe = pipe.to(device) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + + >>> response = requests.get(url) + >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> init_image = init_image.resize((768, 512)) + + >>> prompt = "A fantasy landscape, trending on artstation" + + >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images + >>> images[0].save("fantasy_landscape.png") + ``` +""" + + +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def preprocess(image): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionImg2ImgPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-guided image-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + strength: float = 0.8, + num_inference_steps: int | None = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float | None = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float | None = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. set timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 7.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..a3e09d1ed1ad7e4c34f426db5a8346b459deb8f5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -0,0 +1,1361 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import PIL.Image +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...configuration_utils import FrozenDict +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionInpaintPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`, `AsymmetricAutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "mask", "masked_image_latents"] + + def __init__( + self, + vae: AutoencoderKL | AsymmetricAutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration" + " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" + " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" + " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face" + " Hub, it would be very nice if you could open a Pull request for the" + " `scheduler/scheduler_config.json` file" + ) + deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["skip_prk_steps"] = True + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4 + if unet is not None and unet.config.in_channels != 9: + logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.") + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + mask_image, + height, + width, + strength, + callback_steps, + output_type, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 4: + image_latents = image + else: + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: torch.Tensor = None, + height: int | None = None, + width: int | None = None, + padding_mask_crop: int | None = None, + strength: float = 1.0, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be inpainted (which parts of the image to + be masked out with `mask_image` and repainted according to `prompt`). For both numpy array and pytorch + tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the + expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the + expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image latents as `image`, but + if passing latents directly it is not encoded again. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + Examples: + + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionInpaintPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + >>> init_image = download_image(img_url).resize((512, 512)) + >>> mask_image = download_image(mask_url).resize((512, 512)) + + >>> pipe = StableDiffusionInpaintPipeline.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-inpainting", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0] + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs + self.check_inputs( + prompt, + image, + mask_image, + height, + width, + strength, + callback_steps, + output_type, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + padding_mask_crop, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. set timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Preprocess mask and image + + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask_condition = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + ) + + # 8. Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for stable-diffusion-v1-5/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 9.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 10. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if num_channels_unet == 4: + init_latents_proper = image_latents + if self.do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + mask = callback_outputs.pop("mask", mask) + masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + condition_kwargs = {} + if isinstance(self.vae, AsymmetricAutoencoderKL): + init_image = init_image.to(device=device, dtype=masked_image_latents.dtype) + init_image_condition = init_image.clone() + init_image = self._encode_vae_image(init_image, generator=generator) + mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype) + condition_kwargs = {"image": init_image_condition, "mask": mask_condition} + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False, generator=generator, **condition_kwargs + )[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if padding_mask_crop is not None: + image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py new file mode 100644 index 0000000000000000000000000000000000000000..c89d593d57beea7d4b398f2d56c37f08e822c30b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -0,0 +1,915 @@ +# Copyright 2025 The InstructPix2Pix Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import PIL_INTERPOLATION, deprecate, is_torch_xla_available, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class StableDiffusionInstructPix2PixPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + r""" + Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "image_latents"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection | None = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + num_inference_steps: int = 100, + guidance_scale: float = 7.5, + image_guidance_scale: float = 1.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + cross_attention_kwargs: dict[str, Any] | None = None, + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor` `np.ndarray`, `PIL.Image.Image`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept + image latents as `image`, but if passing latents directly it is not encoded again. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + image_guidance_scale (`float`, *optional*, defaults to 1.5): + Push the generated image towards the initial `image`. Image guidance scale is enabled by setting + `image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely + linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a + value of at least `1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionInstructPix2PixPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" + + >>> image = download_image(img_url).resize((512, 512)) + + >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( + ... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "make the mountains snowy" + >>> image = pipe(prompt=prompt, image=image).images[0] + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Check inputs + self.check_inputs( + prompt, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._image_guidance_scale = image_guidance_scale + + device = self._execution_device + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + # 1. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 2. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + # 3. Preprocess image + image = self.image_processor.preprocess(image) + + # 4. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare Image latents + image_latents = self.prepare_image_latents( + image, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + self.do_classifier_free_guidance, + ) + + height, width = image_latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Check that shapes of latents and image match the UNet channels + num_channels_image = image_latents.shape[1] + if num_channels_latents + num_channels_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_image`: {num_channels_image} " + f" = {num_channels_latents + num_channels_image}. Please verify the config of" + " `pipeline.unet` or your `image` input." + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Expand the latents if we are doing classifier free guidance. + # The latents are expanded 3 times because for pix2pix the guidance\ + # is applied for both the text and the input image. + latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents + + # concat latents, image_latents in the channel dimension + scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1) + + # predict the noise residual + noise_pred = self.unet( + scaled_latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + added_cond_kwargs=added_cond_kwargs, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3) + noise_pred = ( + noise_pred_uncond + + self.guidance_scale * (noise_pred_text - noise_pred_image) + + self.image_guidance_scale * (noise_pred_image - noise_pred_uncond) + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + image_latents = callback_outputs.pop("image_latents", image_latents) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_ prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + else: + prompt_embeds_dtype = self.unet.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] + prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack( + [single_negative_image_embeds] * num_images_per_prompt, dim=0 + ) + + if do_classifier_free_guidance: + single_image_embeds = torch.cat( + [single_image_embeds, single_negative_image_embeds, single_negative_image_embeds] + ) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + else: + repeat_dims = [1] + image_embeds = [] + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + ( + single_image_embeds, + single_negative_image_embeds, + single_negative_image_embeds, + ) = single_image_embeds.chunk(3) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) + single_image_embeds = torch.cat( + [single_image_embeds, single_negative_image_embeds, single_negative_image_embeds] + ) + else: + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + image_embeds.append(single_image_embeds) + + return image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def check_inputs( + self, + prompt, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_image_latents( + self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + image_latents = image + else: + image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax") + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if do_classifier_free_guidance: + uncond_image_latents = torch.zeros_like(image_latents) + image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) + + return image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def image_guidance_scale(self): + return self._image_guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self.guidance_scale > 1.0 and self.image_guidance_scale >= 1.0 diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py new file mode 100644 index 0000000000000000000000000000000000000000..02dc483c277a2db00da926d3a29cd30c38cd499c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py @@ -0,0 +1,664 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPTextModel, CLIPTokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import EulerDiscreteScheduler +from ...utils import deprecate, is_torch_xla_available, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.preprocess +def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 + + image = [np.array(i.resize((w, h)))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin): + r""" + Pipeline for upscaling Stable Diffusion output image resolution by a factor of 2. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A [`EulerDiscreteScheduler`] to be used in combination with `unet` to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: EulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic") + + def _encode_prompt( + self, + prompt, + device, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + device=device, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + **kwargs, + ) + + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds]) + + return prompt_embeds, pooled_prompt_embeds + + def encode_prompt( + self, + prompt, + device, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None or pooled_prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_length=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_encoder_out = self.text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + prompt_embeds = text_encoder_out.hidden_states[-1] + pooled_prompt_embeds = text_encoder_out.pooler_output + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + if negative_prompt_embeds is None or negative_pooled_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_length=True, + return_tensors="pt", + ) + + uncond_encoder_out = self.text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + negative_prompt_embeds = uncond_encoder_out.hidden_states[-1] + negative_pooled_prompt_embeds = uncond_encoder_out.pooler_output + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def check_inputs( + self, + prompt, + image, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ): + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, np.ndarray) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}" + ) + + # verify batch size of prompt and image are same if image is a list or tensor + if isinstance(image, (list, torch.Tensor)): + if prompt is not None: + if isinstance(prompt, str): + batch_size = 1 + else: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if isinstance(image, list): + image_batch_size = len(image) + else: + image_batch_size = image.shape[0] if image.ndim == 4 else 1 + if batch_size != image_batch_size: + raise ValueError( + f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}." + " Please make sure that passed `prompt` matches the batch size of `image`." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height, width) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + num_inference_steps: int = 75, + guidance_scale: float = 9.0, + negative_prompt: str | list[str] | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide image upscaling. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image` or tensor representing an image batch to be upscaled. If it's a tensor, it can be either a + latent output from a Stable Diffusion model or an image tensor in the range `[-1, 1]`. It is considered + a `latent` if `image.shape[1]` is `4`; otherwise, it is considered to be an image representation and + encoded using this pipeline's `vae` encoder. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Examples: + ```py + >>> from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline + >>> import torch + + >>> pipeline = StableDiffusionPipeline.from_pretrained( + ... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16 + ... ) + >>> pipeline.to("cuda") + + >>> model_id = "stabilityai/sd-x2-latent-upscaler" + >>> upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16) + >>> upscaler.to("cuda") + + >>> prompt = "a photo of an astronaut high resolution, unreal engine, ultra realistic" + >>> generator = torch.manual_seed(33) + + >>> low_res_latents = pipeline(prompt, generator=generator, output_type="latent").images + + >>> with torch.no_grad(): + ... image = pipeline.decode_latents(low_res_latents) + >>> image = pipeline.numpy_to_pil(image)[0] + + >>> image.save("../images/a1.png") + + >>> upscaled_image = upscaler( + ... prompt=prompt, + ... image=low_res_latents, + ... num_inference_steps=20, + ... guidance_scale=0, + ... generator=generator, + ... ).images[0] + + >>> upscaled_image.save("../images/a2.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images. + """ + + # 1. Check inputs + self.check_inputs( + prompt, + image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None: + batch_size = 1 if isinstance(prompt, str) else len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if guidance_scale == 0: + prompt = [""] * batch_size + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + device, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds]) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + image = image.to(dtype=prompt_embeds.dtype, device=device) + if image.shape[1] == 3: + # encode image if not in latent-space yet + image = retrieve_latents(self.vae.encode(image), generator=generator) * self.vae.config.scaling_factor + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + batch_multiplier = 2 if do_classifier_free_guidance else 1 + image = image[None, :] if image.ndim == 3 else image + image = torch.cat([image] * batch_multiplier) + + # 5. Add noise to image (set to be 0): + # (see below notes from the author): + # "the This step theoretically can make the model work better on out-of-distribution inputs, but mostly just seems to make it match the input less, so it's turned off by default." + noise_level = torch.tensor([0.0], dtype=torch.float32, device=device) + noise_level = torch.cat([noise_level] * image.shape[0]) + inv_noise_level = (noise_level**2 + 1) ** (-0.5) + + image_cond = F.interpolate(image, scale_factor=2, mode="nearest") * inv_noise_level[:, None, None, None] + image_cond = image_cond.to(prompt_embeds.dtype) + + noise_level_embed = torch.cat( + [ + torch.ones(pooled_prompt_embeds.shape[0], 64, dtype=pooled_prompt_embeds.dtype, device=device), + torch.zeros(pooled_prompt_embeds.shape[0], 64, dtype=pooled_prompt_embeds.dtype, device=device), + ], + dim=1, + ) + + timestep_condition = torch.cat([noise_level_embed, pooled_prompt_embeds], dim=1) + + # 6. Prepare latent variables + height, width = image.shape[2:] + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size, + num_channels_latents, + height * 2, # 2x upscale + width * 2, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Check that sizes of image and latents match + num_channels_image = image.shape[1] + if num_channels_latents + num_channels_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_image`: {num_channels_image} " + f" = {num_channels_latents + num_channels_image}. Please verify the config of" + " `pipeline.unet` or your `image` input." + ) + + # 9. Denoising loop + num_warmup_steps = 0 + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + sigma = self.scheduler.sigmas[i] + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + scaled_model_input = torch.cat([scaled_model_input, image_cond], dim=1) + # preconditioning parameter based on Karras et al. (2022) (table 1) + timestep = torch.log(sigma) * 0.25 + + noise_pred = self.unet( + scaled_model_input, + timestep, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_condition, + ).sample + + # in original repo, the output contains a variance channel that's not used + noise_pred = noise_pred[:, :-1] + + # apply preconditioning, based on table 1 in Karras et al. (2022) + inv_sigma = 1 / (sigma**2 + 1) + noise_pred = inv_sigma * latent_model_input + self.scheduler.scale_model_input(sigma, t) * noise_pred + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py new file mode 100644 index 0000000000000000000000000000000000000000..4befa44550b7f6453e73fc553ae22aeb0b065fe0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -0,0 +1,813 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from . import StableDiffusionPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 + + image = [np.array(i.resize((w, h)))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +class StableDiffusionUpscalePipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-guided image super-resolution using Stable Diffusion 2. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + low_res_scheduler ([`SchedulerMixin`]): + A scheduler used to add initial noise to the low resolution conditioning image. It must be an instance of + [`DDPMScheduler`]. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["watermarker", "safety_checker", "feature_extractor"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + low_res_scheduler: DDPMScheduler, + scheduler: KarrasDiffusionSchedulers, + safety_checker: Any | None = None, + feature_extractor: CLIPImageProcessor | None = None, + watermarker: Any | None = None, + max_noise_level: int = 350, + ): + super().__init__() + + if hasattr( + vae, "config" + ): # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate + is_vae_scaling_factor_set_to_0_08333 = ( + hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333 + ) + if not is_vae_scaling_factor_set_to_0_08333: + deprecation_message = ( + "The configuration file of the vae does not contain `scaling_factor` or it is set to" + f" {vae.config.scaling_factor}, which seems highly unlikely. If your checkpoint is a fine-tuned" + " version of `stabilityai/stable-diffusion-x4-upscaler` you should change 'scaling_factor' to" + " 0.08333 Please make sure to update the config accordingly, as not doing so might lead to" + " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging" + " Face Hub, it would be very nice if you could open a Pull Request for the `vae/config.json` file" + ) + deprecate("wrong scaling_factor", "1.0.0", deprecation_message, standard_warn=False) + vae.register_to_config(scaling_factor=0.08333) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + low_res_scheduler=low_res_scheduler, + scheduler=scheduler, + safety_checker=safety_checker, + watermarker=watermarker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic") + self.register_to_config(max_noise_level=max_noise_level) + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, nsfw_detected, watermark_detected = self.safety_checker( + images=image, + clip_input=safety_checker_input.pixel_values.to(dtype=dtype), + ) + else: + nsfw_detected = None + watermark_detected = None + + if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: + self.unet_offload_hook.offload() + + return image, nsfw_detected, watermark_detected + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def check_inputs( + self, + prompt, + image, + noise_level, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, np.ndarray) + and not isinstance(image, list) + ): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}" + ) + + # verify batch size of prompt and image are same if image is a list or tensor or numpy array + if isinstance(image, (list, np.ndarray, torch.Tensor)): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if isinstance(image, list): + image_batch_size = len(image) + else: + image_batch_size = image.shape[0] + if batch_size != image_batch_size: + raise ValueError( + f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}." + " Please make sure that passed `prompt` matches the batch size of `image`." + ) + + # check noise level + if noise_level > self.config.max_noise_level: + raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height, width) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + num_inference_steps: int = 75, + guidance_scale: float = 9.0, + noise_level: int = 20, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image` or tensor representing an image batch to be upscaled. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + ```py + >>> import requests + >>> from PIL import Image + >>> from io import BytesIO + >>> from diffusers import StableDiffusionUpscalePipeline + >>> import torch + + >>> # load model and scheduler + >>> model_id = "stabilityai/stable-diffusion-x4-upscaler" + >>> pipeline = StableDiffusionUpscalePipeline.from_pretrained( + ... model_id, variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipeline = pipeline.to("cuda") + + >>> # let's download an image + >>> url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png" + >>> response = requests.get(url) + >>> low_res_img = Image.open(BytesIO(response.content)).convert("RGB") + >>> low_res_img = low_res_img.resize((128, 128)) + >>> prompt = "a white cat" + + >>> upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0] + >>> upscaled_image.save("upsampled_cat.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + # 1. Check inputs + self.check_inputs( + prompt, + image, + noise_level, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + image = image.to(dtype=prompt_embeds.dtype, device=device) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Add noise to image + noise_level = torch.tensor([noise_level], dtype=torch.long, device=device) + noise = randn_tensor(image.shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + image = self.low_res_scheduler.add_noise(image, noise, noise_level) + + batch_multiplier = 2 if do_classifier_free_guidance else 1 + image = torch.cat([image] * batch_multiplier * num_images_per_prompt) + noise_level = torch.cat([noise_level] * image.shape[0]) + + # 6. Prepare latent variables + height, width = image.shape[2:] + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Check that sizes of image and latents match + num_channels_image = image.shape[1] + if num_channels_latents + num_channels_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_image`: {num_channels_image} " + f" = {num_channels_latents + num_channels_image}. Please verify the config of" + " `pipeline.unet` or your `image` input." + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latent_model_input, image], dim=1) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=noise_level, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + + # Ensure latents are always the same type as the VAE + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + + image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # 11. Apply watermark + if output_type == "pil" and self.watermarker is not None: + image = self.watermarker.apply_watermark(image) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py new file mode 100644 index 0000000000000000000000000000000000000000..39857ed230e6b0488e089f5ccbb54007fc32cca2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -0,0 +1,952 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from transformers.models.clip.modeling_clip import CLIPTextModelOutput + +from ...image_processor import VaeImageProcessor +from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel +from ...models.embeddings import get_timestep_embedding +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin +from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableUnCLIPPipeline + + >>> pipe = StableUnCLIPPipeline.from_pretrained( + ... "fusing/stable-unclip-2-1-l", torch_dtype=torch.float16 + ... ) # TODO update model path + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> images = pipe(prompt).images + >>> images[0].save("astronaut_horse.png") + ``` +""" + + +class StableUnCLIPPipeline( + DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin +): + """ + Pipeline for text-to-image generation using stable unCLIP. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + prior_tokenizer ([`CLIPTokenizer`]): + A [`CLIPTokenizer`]. + prior_text_encoder ([`CLIPTextModelWithProjection`]): + Frozen [`CLIPTextModelWithProjection`] text-encoder. + prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + prior_scheduler ([`KarrasDiffusionSchedulers`]): + Scheduler used in the prior denoising process. + image_normalizer ([`StableUnCLIPImageNormalizer`]): + Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image + embeddings after the noise has been applied. + image_noising_scheduler ([`KarrasDiffusionSchedulers`]): + Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined + by the `noise_level`. + tokenizer ([`CLIPTokenizer`]): + A [`CLIPTokenizer`]. + text_encoder ([`CLIPTextModel`]): + Frozen [`CLIPTextModel`] text-encoder. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] to denoise the encoded image latents. + scheduler ([`KarrasDiffusionSchedulers`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + """ + + _exclude_from_cpu_offload = ["prior", "image_normalizer"] + model_cpu_offload_seq = "text_encoder->prior_text_encoder->unet->vae" + + # prior components + prior_tokenizer: CLIPTokenizer + prior_text_encoder: CLIPTextModelWithProjection + prior: PriorTransformer + prior_scheduler: KarrasDiffusionSchedulers + + # image noising components + image_normalizer: StableUnCLIPImageNormalizer + image_noising_scheduler: KarrasDiffusionSchedulers + + # regular denoising components + tokenizer: CLIPTokenizer + text_encoder: CLIPTextModel + unet: UNet2DConditionModel + scheduler: KarrasDiffusionSchedulers + + vae: AutoencoderKL + + def __init__( + self, + # prior components + prior_tokenizer: CLIPTokenizer, + prior_text_encoder: CLIPTextModelWithProjection, + prior: PriorTransformer, + prior_scheduler: KarrasDiffusionSchedulers, + # image noising components + image_normalizer: StableUnCLIPImageNormalizer, + image_noising_scheduler: KarrasDiffusionSchedulers, + # regular denoising components + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + # vae + vae: AutoencoderKL, + ): + super().__init__() + + self.register_modules( + prior_tokenizer=prior_tokenizer, + prior_text_encoder=prior_text_encoder, + prior=prior, + prior_scheduler=prior_scheduler, + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + vae=vae, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt with _encode_prompt->_encode_prior_prompt, tokenizer->prior_tokenizer, text_encoder->prior_text_encoder + def _encode_prior_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + text_model_output: CLIPTextModelOutput | tuple | None = None, + text_attention_mask: torch.Tensor | None = None, + ): + if text_model_output is None: + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.prior_tokenizer( + prompt, + padding="max_length", + max_length=self.prior_tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.prior_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.prior_tokenizer.batch_decode( + untruncated_ids[:, self.prior_tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.prior_tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.prior_tokenizer.model_max_length] + + prior_text_encoder_output = self.prior_text_encoder(text_input_ids.to(device)) + + prompt_embeds = prior_text_encoder_output.text_embeds + text_enc_hid_states = prior_text_encoder_output.last_hidden_state + + else: + batch_size = text_model_output[0].shape[0] + prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1] + text_mask = text_attention_mask + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + + uncond_input = self.prior_tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.prior_tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + negative_prompt_embeds_prior_text_encoder_output = self.prior_text_encoder( + uncond_input.input_ids.to(device) + ) + + negative_prompt_embeds = negative_prompt_embeds_prior_text_encoder_output.text_embeds + uncond_text_enc_hid_states = negative_prompt_embeds_prior_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_enc_hid_states.shape[1] + uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1) + uncond_text_enc_hid_states = uncond_text_enc_hid_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_enc_hid_states, text_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with prepare_extra_step_kwargs->prepare_prior_extra_step_kwargs, scheduler->prior_scheduler + def prepare_prior_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the prior_scheduler step, since not all prior_schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other prior_schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.prior_scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the prior_scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.prior_scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + noise_level, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Please make sure to define only one of the two." + ) + + if prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + "Provide either `negative_prompt` or `negative_prompt_embeds`. Cannot leave both `negative_prompt` and `negative_prompt_embeds` undefined." + ) + + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps: + raise ValueError( + f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive." + ) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def noise_image_embeddings( + self, + image_embeds: torch.Tensor, + noise_level: int, + noise: torch.Tensor | None = None, + generator: torch.Generator | None = None, + ): + """ + Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher + `noise_level` increases the variance in the final un-noised images. + + The noise is applied in two ways: + 1. A noise schedule is applied directly to the embeddings. + 2. A vector of sinusoidal time embeddings are appended to the output. + + In both cases, the amount of noise is controlled by the same `noise_level`. + + The embeddings are normalized before the noise is applied and un-normalized after the noise is applied. + """ + if noise is None: + noise = randn_tensor( + image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype + ) + + noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device) + + self.image_normalizer.to(image_embeds.device) + image_embeds = self.image_normalizer.scale(image_embeds) + + image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise) + + image_embeds = self.image_normalizer.unscale(image_embeds) + + noise_level = get_timestep_embedding( + timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0 + ) + + # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors, + # but we might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + noise_level = noise_level.to(image_embeds.dtype) + + image_embeds = torch.cat((image_embeds, noise_level), 1) + + return image_embeds + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + # regular denoising process args + prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 20, + guidance_scale: float = 10.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + noise_level: int = 0, + # prior args + prior_num_inference_steps: int = 25, + prior_guidance_scale: float = 4.0, + prior_latents: torch.Tensor | None = None, + clip_skip: int | None = None, + ): + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 20): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 10.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + noise_level (`int`, *optional*, defaults to `0`): + The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in + the final un-noised images. See [`StableUnCLIPPipeline.noise_image_embeddings`] for more details. + prior_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps in the prior denoising process. More denoising steps usually lead to a + higher quality image at the expense of slower inference. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + prior_latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + embedding generation in the prior denoising process. Can be used to tweak the same generation with + different prompts. If not provided, a latents tensor is generated by sampling using the supplied random + `generator`. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + [`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning + a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_steps=callback_steps, + noise_level=noise_level, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + batch_size = batch_size * num_images_per_prompt + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + prior_do_classifier_free_guidance = prior_guidance_scale > 1.0 + + # 3. Encode input prompt + prior_prompt_embeds, prior_text_encoder_hidden_states, prior_text_mask = self._encode_prior_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=prior_do_classifier_free_guidance, + ) + + # 4. Prepare prior timesteps + self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device) + prior_timesteps_tensor = self.prior_scheduler.timesteps + + # 5. Prepare prior latent variables + embedding_dim = self.prior.config.embedding_dim + prior_latents = self.prepare_latents( + (batch_size, embedding_dim), + prior_prompt_embeds.dtype, + device, + generator, + prior_latents, + self.prior_scheduler, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + prior_extra_step_kwargs = self.prepare_prior_extra_step_kwargs(generator, eta) + + # 7. Prior denoising loop + for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([prior_latents] * 2) if prior_do_classifier_free_guidance else prior_latents + latent_model_input = self.prior_scheduler.scale_model_input(latent_model_input, t) + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=prior_prompt_embeds, + encoder_hidden_states=prior_text_encoder_hidden_states, + attention_mask=prior_text_mask, + ).predicted_image_embedding + + if prior_do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + prior_latents = self.prior_scheduler.step( + predicted_image_embedding, + timestep=t, + sample=prior_latents, + **prior_extra_step_kwargs, + return_dict=False, + )[0] + + if callback is not None and i % callback_steps == 0: + callback(i, t, prior_latents) + + prior_latents = self.prior.post_process_latents(prior_latents) + + image_embeds = prior_latents + + # done prior + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 8. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 9. Prepare image embeddings + image_embeds = self.noise_image_embeddings( + image_embeds=image_embeds, + noise_level=noise_level, + generator=generator, + ) + + if do_classifier_free_guidance: + negative_prompt_embeds = torch.zeros_like(image_embeds) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeds = torch.cat([negative_prompt_embeds, image_embeds]) + + # 10. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 11. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + latents = self.prepare_latents( + shape=shape, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + scheduler=self.scheduler, + ) + + # 12. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 13. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + class_labels=image_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..bb96e5db02959cbcd8d18772eb18e3dcbb63d58f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -0,0 +1,858 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import VaeImageProcessor +from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.embeddings import get_timestep_embedding +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin +from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import requests + >>> import torch + >>> from PIL import Image + >>> from io import BytesIO + + >>> from diffusers import StableUnCLIPImg2ImgPipeline + + >>> pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-1-unclip-small", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + + >>> response = requests.get(url) + >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> init_image = init_image.resize((768, 512)) + + >>> prompt = "A fantasy landscape, trending on artstation" + + >>> images = pipe(init_image, prompt).images + >>> images[0].save("fantasy_landscape.png") + ``` +""" + + +class StableUnCLIPImg2ImgPipeline( + DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin +): + """ + Pipeline for text-guided image-to-image generation using stable unCLIP. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + feature_extractor ([`CLIPImageProcessor`]): + Feature extractor for image pre-processing before being encoded. + image_encoder ([`CLIPVisionModelWithProjection`]): + CLIP vision model for encoding images. + image_normalizer ([`StableUnCLIPImageNormalizer`]): + Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image + embeddings after the noise has been applied. + image_noising_scheduler ([`KarrasDiffusionSchedulers`]): + Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined + by the `noise_level`. + tokenizer (`~transformers.CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`)]. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen [`~transformers.CLIPTextModel`] text-encoder. + unet ([`UNet2DConditionModel`]): + A [`UNet2DConditionModel`] to denoise the encoded image latents. + scheduler ([`KarrasDiffusionSchedulers`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _exclude_from_cpu_offload = ["image_normalizer"] + + # image encoding components + feature_extractor: CLIPImageProcessor + image_encoder: CLIPVisionModelWithProjection + + # image noising components + image_normalizer: StableUnCLIPImageNormalizer + image_noising_scheduler: KarrasDiffusionSchedulers + + # regular denoising components + tokenizer: CLIPTokenizer + text_encoder: CLIPTextModel + unet: UNet2DConditionModel + scheduler: KarrasDiffusionSchedulers + + vae: AutoencoderKL + + def __init__( + self, + # image encoding components + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection, + # image noising components + image_normalizer: StableUnCLIPImageNormalizer, + image_noising_scheduler: KarrasDiffusionSchedulers, + # regular denoising components + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + # vae + vae: AutoencoderKL, + ): + super().__init__() + + self.register_modules( + feature_extractor=feature_extractor, + image_encoder=image_encoder, + image_normalizer=image_normalizer, + image_noising_scheduler=image_noising_scheduler, + tokenizer=tokenizer, + text_encoder=text_encoder, + unet=unet, + scheduler=scheduler, + vae=vae, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + def _encode_image( + self, + image, + device, + batch_size, + num_images_per_prompt, + do_classifier_free_guidance, + noise_level, + generator, + image_embeds, + ): + dtype = next(self.image_encoder.parameters()).dtype + + if isinstance(image, PIL.Image.Image): + # the image embedding should repeated so it matches the total batch size of the prompt + repeat_by = batch_size + else: + # assume the image input is already properly batched and just needs to be repeated so + # it matches the num_images_per_prompt. + # + # NOTE(will) this is probably missing a few number of side cases. I.e. batched/non-batched + # `image_embeds`. If those happen to be common use cases, let's think harder about + # what the expected dimensions of inputs should be and how we handle the encoding. + repeat_by = num_images_per_prompt + + if image_embeds is None: + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(images=image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + + image_embeds = self.noise_image_embeddings( + image_embeds=image_embeds, + noise_level=noise_level, + generator=generator, + ) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + image_embeds = image_embeds.unsqueeze(1) + bs_embed, seq_len, _ = image_embeds.shape + image_embeds = image_embeds.repeat(1, repeat_by, 1) + image_embeds = image_embeds.view(bs_embed * repeat_by, seq_len, -1) + image_embeds = image_embeds.squeeze(1) + + if do_classifier_free_guidance: + negative_prompt_embeds = torch.zeros_like(image_embeds) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeds = torch.cat([negative_prompt_embeds, image_embeds]) + + return image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + height, + width, + callback_steps, + noise_level, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + image_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Please make sure to define only one of the two." + ) + + if prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + "Provide either `negative_prompt` or `negative_prompt_embeds`. Cannot leave both `negative_prompt` and `negative_prompt_embeds` undefined." + ) + + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps: + raise ValueError( + f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive." + ) + + if image is not None and image_embeds is not None: + raise ValueError( + "Provide either `image` or `image_embeds`. Please make sure to define only one of the two." + ) + + if image is None and image_embeds is None: + raise ValueError( + "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined." + ) + + if image is not None: + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `list[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_unclip.StableUnCLIPPipeline.noise_image_embeddings + def noise_image_embeddings( + self, + image_embeds: torch.Tensor, + noise_level: int, + noise: torch.Tensor | None = None, + generator: torch.Generator | None = None, + ): + """ + Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher + `noise_level` increases the variance in the final un-noised images. + + The noise is applied in two ways: + 1. A noise schedule is applied directly to the embeddings. + 2. A vector of sinusoidal time embeddings are appended to the output. + + In both cases, the amount of noise is controlled by the same `noise_level`. + + The embeddings are normalized before the noise is applied and un-normalized after the noise is applied. + """ + if noise is None: + noise = randn_tensor( + image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype + ) + + noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device) + + self.image_normalizer.to(image_embeds.device) + image_embeds = self.image_normalizer.scale(image_embeds) + + image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise) + + image_embeds = self.image_normalizer.unscale(image_embeds) + + noise_level = get_timestep_embedding( + timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0 + ) + + # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors, + # but we might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + noise_level = noise_level.to(image_embeds.dtype) + + image_embeds = torch.cat((image_embeds, noise_level), 1) + + return image_embeds + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: torch.Tensor | PIL.Image.Image = None, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 20, + guidance_scale: float = 10, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + noise_level: int = 0, + image_embeds: torch.Tensor | None = None, + clip_skip: int | None = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, either `prompt_embeds` will be + used or prompt is initialized to `""`. + image (`torch.Tensor` or `PIL.Image.Image`): + `Image` or tensor representing an image batch. The image is encoded to its CLIP embedding which the + `unet` is conditioned on. The image is _not_ encoded by the `vae` and then used as the latents in the + denoising process like it is in the standard Stable Diffusion text-guided image variation process. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 20): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 10.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + noise_level (`int`, *optional*, defaults to `0`): + The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in + the final un-noised images. See [`StableUnCLIPPipeline.noise_image_embeddings`] for more details. + image_embeds (`torch.Tensor`, *optional*): + Pre-generated CLIP embeddings to condition the `unet` on. These latents are not used in the denoising + process. If you want to provide pre-generated latents, pass them to `__call__` as `latents`. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + [`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning + a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + if prompt is None and prompt_embeds is None: + prompt = len(image) * [""] if isinstance(image, list) else "" + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + image=image, + height=height, + width=width, + callback_steps=callback_steps, + noise_level=noise_level, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + image_embeds=image_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + batch_size = batch_size * num_images_per_prompt + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Encoder input image + noise_level = torch.tensor([noise_level], device=device) + image_embeds = self._encode_image( + image=image, + device=device, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + noise_level=noise_level, + generator=generator, + image_embeds=image_embeds, + ) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + if latents is None: + latents = self.prepare_latents( + batch_size=batch_size, + num_channels_latents=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + class_labels=image_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # 9. Post-processing + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/safety_checker.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/safety_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..65daafe01237bba4aca60bd56d6d32dfa5549cc0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -0,0 +1,129 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +import torch.nn as nn +from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel + +from ...utils import is_transformers_version, logging + + +logger = logging.get_logger(__name__) + + +def cosine_distance(image_embeds, text_embeds): + normalized_image_embeds = nn.functional.normalize(image_embeds) + normalized_text_embeds = nn.functional.normalize(text_embeds) + return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) + + +class StableDiffusionSafetyChecker(PreTrainedModel): + config_class = CLIPConfig + main_input_name = "clip_input" + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPConfig): + super().__init__(config) + + self.vision_model = CLIPVisionModel(config.vision_config) + self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False) + + self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False) + self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False) + + self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False) + self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False) + # Model requires post_init after transformers v4.57.3 + if is_transformers_version(">", "4.57.3"): + self.post_init() + + @torch.no_grad() + def forward(self, clip_input, images): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy() + cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy() + + result = [] + batch_size = image_embeds.shape[0] + for i in range(batch_size): + result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} + + # increase this value to create a stronger `nfsw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + for concept_idx in range(len(special_cos_dist[0])): + concept_cos = special_cos_dist[i][concept_idx] + concept_threshold = self.special_care_embeds_weights[concept_idx].item() + result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["special_scores"][concept_idx] > 0: + result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) + adjustment = 0.01 + + for concept_idx in range(len(cos_dist[0])): + concept_cos = cos_dist[i][concept_idx] + concept_threshold = self.concept_embeds_weights[concept_idx].item() + result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["concept_scores"][concept_idx] > 0: + result_img["bad_concepts"].append(concept_idx) + + result.append(result_img) + + has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] + + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if torch.is_tensor(images) or torch.is_tensor(images[0]): + images[idx] = torch.zeros_like(images[idx]) # black image + else: + images[idx] = np.zeros(images[idx].shape) # black image + + if any(has_nsfw_concepts): + logger.warning( + "Potential NSFW content was detected in one or more images. A black image will be returned instead." + " Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts + + @torch.no_grad() + def forward_onnx(self, clip_input: torch.Tensor, images: torch.Tensor): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) + cos_dist = cosine_distance(image_embeds, self.concept_embeds) + + # increase this value to create a stronger `nsfw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment + # special_scores = special_scores.round(decimals=3) + special_care = torch.any(special_scores > 0, dim=1) + special_adjustment = special_care * 0.01 + special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) + + concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment + # concept_scores = concept_scores.round(decimals=3) + has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) + + images[has_nsfw_concepts] = 0.0 # black image + + return images, has_nsfw_concepts diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/safety_checker_flax.py new file mode 100644 index 0000000000000000000000000000000000000000..77a31b91d52d03b9ad3475a089c7ea41cf273ec3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/safety_checker_flax.py @@ -0,0 +1,110 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +import jax.numpy as jnp +from flax import linen as nn +from flax.core.frozen_dict import FrozenDict +from transformers import CLIPConfig, FlaxPreTrainedModel +from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule + + +def jax_cosine_distance(emb_1, emb_2, eps=1e-12): + norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T + norm_emb_2 = jnp.divide(emb_2.T, jnp.clip(jnp.linalg.norm(emb_2, axis=1), a_min=eps)).T + return jnp.matmul(norm_emb_1, norm_emb_2.T) + + +class FlaxStableDiffusionSafetyCheckerModule(nn.Module): + config: CLIPConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.vision_model = FlaxCLIPVisionModule(self.config.vision_config) + self.visual_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype) + + self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.config.projection_dim)) + self.special_care_embeds = self.param( + "special_care_embeds", jax.nn.initializers.ones, (3, self.config.projection_dim) + ) + + self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,)) + self.special_care_embeds_weights = self.param("special_care_embeds_weights", jax.nn.initializers.ones, (3,)) + + def __call__(self, clip_input): + pooled_output = self.vision_model(clip_input)[1] + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds) + cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds) + + # increase this value to create a stronger `nfsw` filter + # at the cost of increasing the possibility of filtering benign image inputs + adjustment = 0.0 + + special_scores = special_cos_dist - self.special_care_embeds_weights[None, :] + adjustment + special_scores = jnp.round(special_scores, 3) + is_special_care = jnp.any(special_scores > 0, axis=1, keepdims=True) + # Use a lower threshold if an image has any special care concept + special_adjustment = is_special_care * 0.01 + + concept_scores = cos_dist - self.concept_embeds_weights[None, :] + special_adjustment + concept_scores = jnp.round(concept_scores, 3) + has_nsfw_concepts = jnp.any(concept_scores > 0, axis=1) + + return has_nsfw_concepts + + +class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel): + config_class = CLIPConfig + main_input_name = "clip_input" + module_class = FlaxStableDiffusionSafetyCheckerModule + + def __init__( + self, + config: CLIPConfig, + input_shape: tuple | None = None, + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + if input_shape is None: + input_shape = (1, 224, 224, 3) + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.Array, input_shape: tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensor + clip_input = jax.random.normal(rng, input_shape) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init(rngs, clip_input)["params"] + + return random_params + + def __call__( + self, + clip_input, + params: dict = None, + ): + clip_input = jnp.transpose(clip_input, (0, 2, 3, 1)) + + return self.module.apply( + {"params": params or self.params}, + jnp.array(clip_input, dtype=jnp.float32), + rngs={}, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py new file mode 100644 index 0000000000000000000000000000000000000000..ba91a0f2392398e60f3e837e7a4b5de28091dece --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py @@ -0,0 +1,55 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin + + +class StableUnCLIPImageNormalizer(ModelMixin, ConfigMixin): + """ + This class is used to hold the mean and standard deviation of the CLIP embedder used in stable unCLIP. + + It is used to normalize the image embeddings before the noise is applied and un-normalize the noised image + embeddings. + """ + + @register_to_config + def __init__( + self, + embedding_dim: int = 768, + ): + super().__init__() + + self.mean = nn.Parameter(torch.zeros(1, embedding_dim)) + self.std = nn.Parameter(torch.ones(1, embedding_dim)) + + def to( + self, + torch_device: str | torch.device | None = None, + torch_dtype: torch.dtype | None = None, + ): + self.mean = nn.Parameter(self.mean.to(torch_device).to(torch_dtype)) + self.std = nn.Parameter(self.std.to(torch_device).to(torch_dtype)) + return self + + def scale(self, embeds): + embeds = (embeds - self.mean) * 1.0 / self.std + return embeds + + def unscale(self, embeds): + embeds = (embeds * self.std) + self.mean + return embeds diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_3/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b0604589a208d0c5c1270629dcad144395b750cb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_3/__init__.py @@ -0,0 +1,54 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["StableDiffusion3PipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_3"] = ["StableDiffusion3Pipeline"] + _import_structure["pipeline_stable_diffusion_3_img2img"] = ["StableDiffusion3Img2ImgPipeline"] + _import_structure["pipeline_stable_diffusion_3_inpaint"] = ["StableDiffusion3InpaintPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_stable_diffusion_3 import StableDiffusion3Pipeline + from .pipeline_stable_diffusion_3_img2img import StableDiffusion3Img2ImgPipeline + from .pipeline_stable_diffusion_3_inpaint import StableDiffusion3InpaintPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..43f801d81e2e00ae6563eb4f6aa8d2e614d813bc --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class StableDiffusion3PipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image] | np.ndarray diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py new file mode 100644 index 0000000000000000000000000000000000000000..7764a79d7faf9f57658aa2772306878eb6890224 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -0,0 +1,1144 @@ +# Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + SiglipImageProcessor, + SiglipVisionModel, + T5EncoderModel, + T5TokenizerFast, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import SD3Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import StableDiffusion3PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusion3Pipeline + + >>> pipe = StableDiffusion3Pipeline.from_pretrained( + ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> image = pipe(prompt).images[0] + >>> image.save("sd3.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin): + r""" + Args: + transformer ([`SD3Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, + with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size` + as its dimension. + text_encoder_2 ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + text_encoder_3 ([`T5EncoderModel`]): + Frozen text-encoder. Stable Diffusion 3 uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_3 (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + image_encoder (`SiglipVisionModel`, *optional*): + Pre-trained Vision Model for IP Adapter. + feature_extractor (`SiglipImageProcessor`, *optional*): + Image processor for IP Adapter. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "pooled_prompt_embeds"] + + def __init__( + self, + transformer: SD3Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5TokenizerFast, + image_encoder: SiglipVisionModel = None, + feature_extractor: SiglipImageProcessor = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + transformer=transformer, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + self.patch_size = ( + self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2 + ) + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 256, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if self.text_encoder_3 is None: + return torch.zeros( + ( + batch_size * num_images_per_prompt, + max_sequence_length, + self.transformer.config.joint_attention_dim, + ), + device=device, + dtype=dtype, + ) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] + + dtype = self.text_encoder_3.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + clip_skip: int | None = None, + clip_model_index: int = 0, + ): + device = device or self._execution_device + + clip_tokenizers = [self.tokenizer, self.tokenizer_2] + clip_text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str], + prompt_3: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_3: str | list[str] | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + clip_skip: int | None = None, + max_sequence_length: int = 256, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + prompt=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=1, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( + negative_prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( + negative_prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) + + t5_negative_prompt_embed = self._get_t5_prompt_embeds( + prompt=negative_prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_clip_prompt_embeds = torch.nn.functional.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + + negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + ) + + if self.text_encoder is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if ( + height % (self.vae_scale_factor * self.patch_size) != 0 + or width % (self.vae_scale_factor * self.patch_size) != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}." + f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_3 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): + raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_3 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def skip_guidance_layers(self): + return self._skip_guidance_layers + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image + def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor: + """Encodes the given image into a feature representation using a pre-trained image encoder. + + Args: + image (`PipelineImageInput`): + Input image to be encoded. + device: (`torch.device`): + Torch device. + + Returns: + `torch.Tensor`: The encoded image feature representation. + """ + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=self.dtype) + + return self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: torch.Tensor | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + ) -> torch.Tensor: + """Prepares image embeddings for use in the IP-Adapter. + + Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed. + + Args: + ip_adapter_image (`PipelineImageInput`, *optional*): + The input image to extract features from for IP-Adapter. + ip_adapter_image_embeds (`torch.Tensor`, *optional*): + Precomputed image embeddings. + device: (`torch.device`, *optional*): + Torch device. + num_images_per_prompt (`int`, defaults to 1): + Number of images that should be generated per prompt. + do_classifier_free_guidance (`bool`, defaults to True): + Whether to use classifier free guidance or not. + """ + device = device or self._execution_device + + if ip_adapter_image_embeds is not None: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2) + else: + single_image_embeds = ip_adapter_image_embeds + elif ip_adapter_image is not None: + single_image_embeds = self.encode_image(ip_adapter_image, device) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.zeros_like(single_image_embeds) + else: + raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.") + + image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0) + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + + return image_embeds.to(device=device) + + def enable_sequential_cpu_offload(self, *args, **kwargs): + if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload: + logger.warning( + "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses " + "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling " + "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`." + ) + + super().enable_sequential_cpu_offload(*args, **kwargs) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + prompt_3: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 28, + sigmas: list[float] | None = None, + guidance_scale: float = 7.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_3: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 256, + skip_guidance_layers: list[int] = None, + skip_layer_guidance_scale: float = 2.8, + skip_layer_guidance_stop: float = 0.2, + skip_layer_guidance_start: float = 0.01, + mu: float | None = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used instead + negative_prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used instead + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images, + emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to + `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of + a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + skip_guidance_layers (`list[int]`, *optional*): + A list of integers that specify layers to skip during guidance. If not provided, all layers will be + used for guidance. If provided, the guidance will only be applied to the layers specified in the list. + Recommended value by StabiltyAI for Stable Diffusion 3.5 Medium is [7, 8, 9]. + skip_layer_guidance_scale (`int`, *optional*): The scale of the guidance for the layers specified in + `skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers` + with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers + with a scale of `1`. + skip_layer_guidance_stop (`int`, *optional*): The step at which the guidance for the layers specified in + `skip_guidance_layers` will stop. The guidance will be applied to the layers specified in + `skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by + StabiltyAI for Stable Diffusion 3.5 Medium is 0.2. + skip_layer_guidance_start (`int`, *optional*): The step at which the guidance for the layers specified in + `skip_guidance_layers` will start. The guidance will be applied to the layers specified in + `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by + StabiltyAI for Stable Diffusion 3.5 Medium is 0.01. + mu (`float`, *optional*): `mu` value used for `dynamic_shifting`. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._skip_layer_guidance_scale = skip_layer_guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + if self.do_classifier_free_guidance: + if skip_guidance_layers is not None: + original_prompt_embeds = prompt_embeds + original_pooled_prompt_embeds = pooled_prompt_embeds + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + scheduler_kwargs = {} + if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None: + _, _, height, width = latents.shape + image_seq_len = (height // self.transformer.config.patch_size) * ( + width // self.transformer.config.patch_size + ) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), + ) + scheduler_kwargs["mu"] = mu + elif mu is not None: + scheduler_kwargs["mu"] = mu + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + timestep_device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Prepare image embeddings + if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds} + else: + self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds) + + # 7. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + should_skip_layers = ( + True + if i > num_inference_steps * skip_layer_guidance_start + and i < num_inference_steps * skip_layer_guidance_stop + else False + ) + if skip_guidance_layers is not None and should_skip_layers: + timestep = t.expand(latents.shape[0]) + latent_model_input = latents + noise_pred_skip_layers = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=original_prompt_embeds, + pooled_projections=original_pooled_prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + skip_layers=skip_guidance_layers, + )[0] + noise_pred = ( + noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..7951b970cd0c882107b0e236ef8918bd595ca691 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py @@ -0,0 +1,1158 @@ +# Copyright 2025 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import PIL.Image +import torch +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + SiglipImageProcessor, + SiglipVisionModel, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import SD3Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import StableDiffusion3PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + + >>> from diffusers import AutoPipelineForImage2Image + >>> from diffusers.utils import load_image + + >>> device = "cuda" + >>> model_id_or_path = "stabilityai/stable-diffusion-3-medium-diffusers" + >>> pipe = AutoPipelineForImage2Image.from_pretrained(model_id_or_path, torch_dtype=torch.float16) + >>> pipe = pipe.to(device) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> init_image = load_image(url).resize((1024, 1024)) + + >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k" + + >>> images = pipe(prompt=prompt, image=init_image, strength=0.95, guidance_scale=7.5).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin): + r""" + Args: + transformer ([`SD3Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, + with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size` + as its dimension. + text_encoder_2 ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + text_encoder_3 ([`T5EncoderModel`]): + Frozen text-encoder. Stable Diffusion 3 uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_3 (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + image_encoder (`SiglipVisionModel`, *optional*): + Pre-trained Vision Model for IP Adapter. + feature_extractor (`SiglipImageProcessor`, *optional*): + Image processor for IP Adapter. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] + + def __init__( + self, + transformer: SD3Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5TokenizerFast, + image_encoder: SiglipVisionModel | None = None, + feature_extractor: SiglipImageProcessor | None = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + transformer=transformer, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, vae_latent_channels=latent_channels + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + self.patch_size = ( + self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2 + ) + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 256, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if self.text_encoder_3 is None: + return torch.zeros( + ( + batch_size * num_images_per_prompt, + max_sequence_length, + self.transformer.config.joint_attention_dim, + ), + device=device, + dtype=dtype, + ) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] + + dtype = self.text_encoder_3.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + clip_skip: int | None = None, + clip_model_index: int = 0, + ): + device = device or self._execution_device + + clip_tokenizers = [self.tokenizer, self.tokenizer_2] + clip_text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str], + prompt_3: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_3: str | list[str] | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + clip_skip: int | None = None, + max_sequence_length: int = 256, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + prompt=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=1, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( + negative_prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( + negative_prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) + + t5_negative_prompt_embed = self._get_t5_prompt_embeds( + prompt=negative_prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_clip_prompt_embeds = torch.nn.functional.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + + negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + ) + + if self.text_encoder is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + height, + width, + strength, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if ( + height % (self.vae_scale_factor * self.patch_size) != 0 + or width % (self.vae_scale_factor * self.patch_size) != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}." + f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}." + ) + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_3 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): + raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_3 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + if image.shape[1] == self.vae.config.latent_channels: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + init_latents = (init_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.scale_noise(init_latents, timestep, noise) + latents = init_latents.to(device=device, dtype=dtype) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image + def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor: + """Encodes the given image into a feature representation using a pre-trained image encoder. + + Args: + image (`PipelineImageInput`): + Input image to be encoded. + device: (`torch.device`): + Torch device. + + Returns: + `torch.Tensor`: The encoded image feature representation. + """ + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=self.dtype) + + return self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: torch.Tensor | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + ) -> torch.Tensor: + """Prepares image embeddings for use in the IP-Adapter. + + Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed. + + Args: + ip_adapter_image (`PipelineImageInput`, *optional*): + The input image to extract features from for IP-Adapter. + ip_adapter_image_embeds (`torch.Tensor`, *optional*): + Precomputed image embeddings. + device: (`torch.device`, *optional*): + Torch device. + num_images_per_prompt (`int`, defaults to 1): + Number of images that should be generated per prompt. + do_classifier_free_guidance (`bool`, defaults to True): + Whether to use classifier free guidance or not. + """ + device = device or self._execution_device + + if ip_adapter_image_embeds is not None: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2) + else: + single_image_embeds = ip_adapter_image_embeds + elif ip_adapter_image is not None: + single_image_embeds = self.encode_image(ip_adapter_image, device) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.zeros_like(single_image_embeds) + else: + raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.") + + image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0) + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + + return image_embeds.to(device=device) + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, *args, **kwargs): + if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload: + logger.warning( + "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses " + "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling " + "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`." + ) + + super().enable_sequential_cpu_offload(*args, **kwargs) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + prompt_3: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + image: PipelineImageInput = None, + strength: float = 0.6, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 7.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_3: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: torch.Tensor | None = None, + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 256, + mu: float | None = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + will be used instead + height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used instead + negative_prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used instead + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images, + emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to + `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of + a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + mu (`float`, *optional*): `mu` value used for `dynamic_shifting`. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + strength, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 3. Preprocess image + image = self.image_processor.preprocess(image, height=height, width=width) + + # 4. Prepare timesteps + scheduler_kwargs = {} + if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None: + image_seq_len = (int(height) // self.vae_scale_factor // self.transformer.config.patch_size) * ( + int(width) // self.vae_scale_factor // self.transformer.config.patch_size + ) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), + ) + scheduler_kwargs["mu"] = mu + elif mu is not None: + scheduler_kwargs["mu"] = mu + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas, **scheduler_kwargs + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. Prepare latent variables + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) + + # 6. Prepare image embeddings + if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds} + else: + self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds) + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..d3594b868f89992d109f34bfca5f2ecf9bb206f7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py @@ -0,0 +1,1384 @@ +# Copyright 2025 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + SiglipImageProcessor, + SiglipVisionModel, + T5EncoderModel, + T5TokenizerFast, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import SD3Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import StableDiffusion3PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusion3InpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = StableDiffusion3InpaintPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + >>> source = load_image(img_url) + >>> mask = load_image(mask_url) + >>> image = pipe(prompt=prompt, image=source, mask_image=mask).images[0] + >>> image.save("sd3_inpainting.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin): + r""" + Args: + transformer ([`SD3Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant, + with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size` + as its dimension. + text_encoder_2 ([`CLIPTextModelWithProjection`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + text_encoder_3 ([`T5EncoderModel`]): + Frozen text-encoder. Stable Diffusion 3 uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_3 (`T5TokenizerFast`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + image_encoder (`SiglipVisionModel`, *optional*): + Pre-trained Vision Model for IP Adapter. + feature_extractor (`SiglipImageProcessor`, *optional*): + Image processor for IP Adapter. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] + + def __init__( + self, + transformer: SD3Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer_2: CLIPTokenizer, + text_encoder_3: T5EncoderModel, + tokenizer_3: T5TokenizerFast, + image_encoder: SiglipVisionModel | None = None, + feature_extractor: SiglipImageProcessor | None = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + text_encoder_3=text_encoder_3, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + tokenizer_3=tokenizer_3, + transformer=transformer, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, vae_latent_channels=latent_channels + ) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, + vae_latent_channels=latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + self.patch_size = ( + self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2 + ) + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 256, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if self.text_encoder_3 is None: + return torch.zeros( + ( + batch_size * num_images_per_prompt, + max_sequence_length, + self.transformer.config.joint_attention_dim, + ), + device=device, + dtype=dtype, + ) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0] + + dtype = self.text_encoder_3.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + clip_skip: int | None = None, + clip_model_index: int = 0, + ): + device = device or self._execution_device + + clip_tokenizers = [self.tokenizer, self.tokenizer_2] + clip_text_encoders = [self.text_encoder, self.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + prompt_2: str | list[str], + prompt_3: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_3: str | list[str] | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + clip_skip: int | None = None, + max_sequence_length: int = 256, + lora_scale: float | None = None, + ): + r""" + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + negative_prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + prompt=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=1, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( + negative_prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( + negative_prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) + + t5_negative_prompt_embed = self._get_t5_prompt_embeds( + prompt=negative_prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_clip_prompt_embeds = torch.nn.functional.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + + negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + ) + + if self.text_encoder is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + prompt_3, + height, + width, + strength, + negative_prompt=None, + negative_prompt_2=None, + negative_prompt_3=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if ( + height % (self.vae_scale_factor * self.patch_size) != 0 + or width % (self.vae_scale_factor * self.patch_size) != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}." + f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}." + ) + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_3 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)): + raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_3 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 16: + image_latents = image + else: + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.scale_noise(image_latents, timestep, noise) + else: + noise = latents.to(device) + latents = noise + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + do_classifier_free_guidance, + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 16: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image + def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor: + """Encodes the given image into a feature representation using a pre-trained image encoder. + + Args: + image (`PipelineImageInput`): + Input image to be encoded. + device: (`torch.device`): + Torch device. + + Returns: + `torch.Tensor`: The encoded image feature representation. + """ + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=self.dtype) + + return self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: torch.Tensor | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + ) -> torch.Tensor: + """Prepares image embeddings for use in the IP-Adapter. + + Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed. + + Args: + ip_adapter_image (`PipelineImageInput`, *optional*): + The input image to extract features from for IP-Adapter. + ip_adapter_image_embeds (`torch.Tensor`, *optional*): + Precomputed image embeddings. + device: (`torch.device`, *optional*): + Torch device. + num_images_per_prompt (`int`, defaults to 1): + Number of images that should be generated per prompt. + do_classifier_free_guidance (`bool`, defaults to True): + Whether to use classifier free guidance or not. + """ + device = device or self._execution_device + + if ip_adapter_image_embeds is not None: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2) + else: + single_image_embeds = ip_adapter_image_embeds + elif ip_adapter_image is not None: + single_image_embeds = self.encode_image(ip_adapter_image, device) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.zeros_like(single_image_embeds) + else: + raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.") + + image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0) + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + + return image_embeds.to(device=device) + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, *args, **kwargs): + if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload: + logger.warning( + "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses " + "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling " + "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`." + ) + + super().enable_sequential_cpu_offload(*args, **kwargs) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + prompt_3: str | list[str] | None = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: PipelineImageInput = None, + height: int = None, + width: int = None, + padding_mask_crop: int | None = None, + strength: float = 0.6, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 7.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_3: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 256, + mu: float | None = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead + prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is + will be used instead + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + mask_image_latent (`torch.Tensor`, `list[torch.Tensor]`): + `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask + latents tensor will be generated by `mask_image`. + height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used instead + negative_prompt_3 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and + `text_encoder_3`. If not defined, `negative_prompt` is used instead + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images, + emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to + `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of + a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + mu (`float`, *optional*): `mu` value used for `dynamic_shifting`. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + strength, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 3. Prepare timesteps + scheduler_kwargs = {} + if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None: + image_seq_len = (int(height) // self.vae_scale_factor // self.transformer.config.patch_size) * ( + int(width) // self.vae_scale_factor // self.transformer.config.patch_size + ) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), + ) + scheduler_kwargs["mu"] = mu + elif mu is not None: + scheduler_kwargs["mu"] = mu + + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas, **scheduler_kwargs + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 4. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 5. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_transformer = self.transformer.config.in_channels + return_image_latents = num_channels_transformer == 16 + + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 6. Prepare mask latent variables + mask_condition = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + ) + + # match the inpainting pipeline and will be updated with input + mask inpainting model later + if num_channels_transformer == 33: + # default case for stable-diffusion-v1-5/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if ( + num_channels_latents + num_channels_mask + num_channels_masked_image + != self.transformer.config.in_channels + ): + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects" + f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" + " `pipeline.transformer` or your `mask_image` or `image` input." + ) + elif num_channels_transformer != 16: + raise ValueError( + f"The transformer {self.transformer.__class__} should have 16 input channels or 33 input channels, not {self.transformer.config.in_channels}." + ) + + # 7. Prepare image embeddings + if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds} + else: + self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + if num_channels_transformer == 33: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + if num_channels_transformer == 16: + init_latents_proper = image_latents + if self.do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + mask = callback_outputs.pop("mask", mask) + masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + else: + image = latents + + do_denormalize = [True] * image.shape[0] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if padding_mask_crop is not None: + image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cce556fceb2379be482d383e380a09836d25ce3b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_attend_and_excite/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_attend_and_excite/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47fbf91cdb1708c0c5b4b08de314f1310a4cfe32 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_attend_and_excite/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_attend_and_excite/__pycache__/pipeline_stable_diffusion_attend_and_excite.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_attend_and_excite/__pycache__/pipeline_stable_diffusion_attend_and_excite.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f99da28d7773fa1adfa1a786686d48204e0409b0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_attend_and_excite/__pycache__/pipeline_stable_diffusion_attend_and_excite.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py new file mode 100644 index 0000000000000000000000000000000000000000..80b0c09bc9a5197612afb8ef9cbb4915fd875f5a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py @@ -0,0 +1,1113 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable + +import numpy as np +import torch +from torch.nn import functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention_processor import Attention +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionAttendAndExcitePipeline + + >>> pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained( + ... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16 + ... ).to("cuda") + + + >>> prompt = "a cat and a frog" + + >>> # use get_indices function to find out indices of the tokens you want to alter + >>> pipe.get_indices(prompt) + {0: '<|startoftext|>', 1: 'a', 2: 'cat', 3: 'and', 4: 'a', 5: 'frog', 6: '<|endoftext|>'} + + >>> token_indices = [2, 5] + >>> seed = 6141 + >>> generator = torch.Generator("cuda").manual_seed(seed) + + >>> images = pipe( + ... prompt=prompt, + ... token_indices=token_indices, + ... guidance_scale=7.5, + ... generator=generator, + ... num_inference_steps=50, + ... max_iter_to_alter=25, + ... ).images + + >>> image = images[0] + >>> image.save(f"../images/{prompt}_{seed}.png") + ``` +""" + + +class AttentionStore: + @staticmethod + def get_empty_store(): + return {"down": [], "mid": [], "up": []} + + def __call__(self, attn, is_cross: bool, place_in_unet: str): + if self.cur_att_layer >= 0 and is_cross: + if attn.shape[1] == np.prod(self.attn_res): + self.step_store[place_in_unet].append(attn) + + self.cur_att_layer += 1 + if self.cur_att_layer == self.num_att_layers: + self.cur_att_layer = 0 + self.between_steps() + + def between_steps(self): + self.attention_store = self.step_store + self.step_store = self.get_empty_store() + + def get_average_attention(self): + average_attention = self.attention_store + return average_attention + + def aggregate_attention(self, from_where: list[str]) -> torch.Tensor: + """Aggregates the attention across the different layers and heads at the specified resolution.""" + out = [] + attention_maps = self.get_average_attention() + for location in from_where: + for item in attention_maps[location]: + cross_maps = item.reshape(-1, self.attn_res[0], self.attn_res[1], item.shape[-1]) + out.append(cross_maps) + out = torch.cat(out, dim=0) + out = out.sum(0) / out.shape[0] + return out + + def reset(self): + self.cur_att_layer = 0 + self.step_store = self.get_empty_store() + self.attention_store = {} + + def __init__(self, attn_res): + """ + Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion + process + """ + self.num_att_layers = -1 + self.cur_att_layer = 0 + self.step_store = self.get_empty_store() + self.attention_store = {} + self.curr_step_index = 0 + self.attn_res = attn_res + + +class AttendExciteAttnProcessor: + def __init__(self, attnstore, place_in_unet): + super().__init__() + self.attnstore = attnstore + self.place_in_unet = place_in_unet + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + query = attn.to_q(hidden_states) + + is_cross = encoder_hidden_states is not None + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + + # only need to store attention maps during the Attend and Excite process + if attention_probs.requires_grad: + self.attnstore(attention_probs, is_cross, self.place_in_unet) + + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class StableDiffusionAttendAndExcitePipeline( + DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin +): + r""" + Pipeline for text-to-image generation using Stable Diffusion and Attend-and-Excite. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + _last_supported_version = "0.33.1" + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + indices, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + indices_is_list_ints = isinstance(indices, list) and isinstance(indices[0], int) + indices_is_list_list_ints = ( + isinstance(indices, list) and isinstance(indices[0], list) and isinstance(indices[0][0], int) + ) + + if not indices_is_list_ints and not indices_is_list_list_ints: + raise TypeError("`indices` must be a list of ints or a list of a list of ints") + + if indices_is_list_ints: + indices_batch_size = 1 + elif indices_is_list_list_ints: + indices_batch_size = len(indices) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if indices_batch_size != prompt_batch_size: + raise ValueError( + f"indices batch size must be same as prompt batch size. indices batch size: {indices_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @staticmethod + def _compute_max_attention_per_index( + attention_maps: torch.Tensor, + indices: list[int], + ) -> list[torch.Tensor]: + """Computes the maximum attention value for each of the tokens we wish to alter.""" + attention_for_text = attention_maps[:, :, 1:-1] + attention_for_text *= 100 + attention_for_text = torch.nn.functional.softmax(attention_for_text, dim=-1) + + # Shift indices since we removed the first token + indices = [index - 1 for index in indices] + + # Extract the maximum values + max_indices_list = [] + for i in indices: + image = attention_for_text[:, :, i] + smoothing = GaussianSmoothing().to(attention_maps.device) + input = F.pad(image.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode="reflect") + image = smoothing(input).squeeze(0).squeeze(0) + max_indices_list.append(image.max()) + return max_indices_list + + def _aggregate_and_get_max_attention_per_token( + self, + indices: list[int], + ): + """Aggregates the attention for each token and computes the max activation value for each token to alter.""" + attention_maps = self.attention_store.aggregate_attention( + from_where=("up", "down", "mid"), + ) + max_attention_per_index = self._compute_max_attention_per_index( + attention_maps=attention_maps, + indices=indices, + ) + return max_attention_per_index + + @staticmethod + def _compute_loss(max_attention_per_index: list[torch.Tensor]) -> torch.Tensor: + """Computes the attend-and-excite loss using the maximum attention value for each token.""" + losses = [max(0, 1.0 - curr_max) for curr_max in max_attention_per_index] + loss = max(losses) + return loss + + @staticmethod + def _update_latent(latents: torch.Tensor, loss: torch.Tensor, step_size: float) -> torch.Tensor: + """Update the latent according to the computed loss.""" + grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents], retain_graph=True)[0] + latents = latents - step_size * grad_cond + return latents + + def _perform_iterative_refinement_step( + self, + latents: torch.Tensor, + indices: list[int], + loss: torch.Tensor, + threshold: float, + text_embeddings: torch.Tensor, + step_size: float, + t: int, + max_refinement_steps: int = 20, + ): + """ + Performs the iterative latent refinement introduced in the paper. Here, we continuously update the latent code + according to our loss objective until the given threshold is reached for all tokens. + """ + iteration = 0 + target_loss = max(0, 1.0 - threshold) + while loss > target_loss: + iteration += 1 + + latents = latents.clone().detach().requires_grad_(True) + self.unet(latents, t, encoder_hidden_states=text_embeddings).sample + self.unet.zero_grad() + + # Get max activation value for each subject token + max_attention_per_index = self._aggregate_and_get_max_attention_per_token( + indices=indices, + ) + + loss = self._compute_loss(max_attention_per_index) + + if loss != 0: + latents = self._update_latent(latents, loss, step_size) + + logger.info(f"\t Try {iteration}. loss: {loss}") + + if iteration >= max_refinement_steps: + logger.info(f"\t Exceeded max number of iterations ({max_refinement_steps})! ") + break + + # Run one more time but don't compute gradients and update the latents. + # We just need to compute the new loss - the grad update will occur below + latents = latents.clone().detach().requires_grad_(True) + _ = self.unet(latents, t, encoder_hidden_states=text_embeddings).sample + self.unet.zero_grad() + + # Get max activation value for each subject token + max_attention_per_index = self._aggregate_and_get_max_attention_per_token( + indices=indices, + ) + loss = self._compute_loss(max_attention_per_index) + logger.info(f"\t Finished with loss of: {loss}") + return loss, latents, max_attention_per_index + + def register_attention_control(self): + attn_procs = {} + cross_att_count = 0 + for name in self.unet.attn_processors.keys(): + if name.startswith("mid_block"): + place_in_unet = "mid" + elif name.startswith("up_blocks"): + place_in_unet = "up" + elif name.startswith("down_blocks"): + place_in_unet = "down" + else: + continue + + cross_att_count += 1 + attn_procs[name] = AttendExciteAttnProcessor(attnstore=self.attention_store, place_in_unet=place_in_unet) + + self.unet.set_attn_processor(attn_procs) + self.attention_store.num_att_layers = cross_att_count + + def get_indices(self, prompt: str) -> dict[str, int]: + """Utility function to list the indices of the tokens you wish to alte""" + ids = self.tokenizer(prompt).input_ids + indices = {i: tok for tok, i in zip(self.tokenizer.convert_ids_to_tokens(ids), range(len(ids)))} + return indices + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str], + token_indices: list[int] | list[list[int]], + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + max_iter_to_alter: int = 25, + thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8}, + scale_factor: int = 20, + attn_res: tuple[int] | None = (16, 16), + clip_skip: int | None = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + token_indices (`list[int]`): + The token indices to alter with attend-and-excite. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + max_iter_to_alter (`int`, *optional*, defaults to `25`): + Number of denoising steps to apply attend-and-excite. The `max_iter_to_alter` denoising steps are when + attend-and-excite is applied. For example, if `max_iter_to_alter` is `25` and there are a total of `30` + denoising steps, the first `25` denoising steps applies attend-and-excite and the last `5` will not. + thresholds (`dict`, *optional*, defaults to `{0: 0.05, 10: 0.5, 20: 0.8}`): + Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in. + scale_factor (`int`, *optional*, default to 20): + Scale factor to control the step size of each attend-and-excite update. + attn_res (`tuple`, *optional*, default computed from width and height): + The 2D resolution of the semantic attention map. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + token_indices, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + if attn_res is None: + attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32)) + self.attention_store = AttentionStore(attn_res) + original_attn_proc = self.unet.attn_processors + self.register_attention_control() + + # default config for step size from original repo + scale_range = np.linspace(1.0, 0.5, len(self.scheduler.timesteps)) + step_size = scale_factor * np.sqrt(scale_range) + + text_embeddings = ( + prompt_embeds[batch_size * num_images_per_prompt :] if do_classifier_free_guidance else prompt_embeds + ) + + if isinstance(token_indices[0], int): + token_indices = [token_indices] + + indices = [] + + for ind in token_indices: + indices = indices + [ind] * num_images_per_prompt + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Attend and excite process + with torch.enable_grad(): + latents = latents.clone().detach().requires_grad_(True) + updated_latents = [] + for latent, index, text_embedding in zip(latents, indices, text_embeddings): + # Forward pass of denoising with text conditioning + latent = latent.unsqueeze(0) + text_embedding = text_embedding.unsqueeze(0) + + self.unet( + latent, + t, + encoder_hidden_states=text_embedding, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + self.unet.zero_grad() + + # Get max activation value for each subject token + max_attention_per_index = self._aggregate_and_get_max_attention_per_token( + indices=index, + ) + + loss = self._compute_loss(max_attention_per_index=max_attention_per_index) + + # If this is an iterative refinement step, verify we have reached the desired threshold for all + if i in thresholds.keys() and loss > 1.0 - thresholds[i]: + loss, latent, max_attention_per_index = self._perform_iterative_refinement_step( + latents=latent, + indices=index, + loss=loss, + threshold=thresholds[i], + text_embeddings=text_embedding, + step_size=step_size[i], + t=t, + ) + + # Perform gradient update + if i < max_iter_to_alter: + if loss != 0: + latent = self._update_latent( + latents=latent, + loss=loss, + step_size=step_size[i], + ) + logger.info(f"Iteration {i} | Loss: {loss:0.4f}") + + updated_latents.append(latent) + + latents = torch.cat(updated_latents, dim=0) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # 8. Post-processing + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + self.maybe_free_model_hooks() + # make sure to set the original attention processors back + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + +class GaussianSmoothing(torch.nn.Module): + """ + Arguments: + Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed separately for each channel in the input + using a depthwise convolution. + channels (int, sequence): Number of channels of the input tensors. Output will + have this number of channels as well. + kernel_size (int, sequence): Size of the gaussian kernel. sigma (float, sequence): Standard deviation of the + gaussian kernel. dim (int, optional): The number of dimensions of the data. + Default value is 2 (spatial). + """ + + # channels=1, kernel_size=kernel_size, sigma=sigma, dim=2 + def __init__( + self, + channels: int = 1, + kernel_size: int = 3, + sigma: float = 0.5, + dim: int = 2, + ): + super().__init__() + + if isinstance(kernel_size, int): + kernel_size = [kernel_size] * dim + if isinstance(sigma, float): + sigma = [sigma] * dim + + # The gaussian kernel is the product of the + # gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size]) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2)) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer("weight", kernel) + self.groups = channels + + if dim == 1: + self.conv = F.conv1d + elif dim == 2: + self.conv = F.conv2d + elif dim == 3: + self.conv = F.conv3d + else: + raise RuntimeError("Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim)) + + def forward(self, input): + """ + Arguments: + Apply gaussian filter to input. + input (torch.Tensor): Input to apply gaussian filter on. + Returns: + filtered (torch.Tensor): Filtered output. + """ + return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_diffedit/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_diffedit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e2145edb96c6be124abf9e9a21b9a5e8a3f3d641 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_diffedit/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py new file mode 100644 index 0000000000000000000000000000000000000000..43bc2eb955c7951fe896d07c1bb377f86d2449d7 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py @@ -0,0 +1,1548 @@ +# Copyright 2025 DiffEdit Authors and Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import DDIMInverseScheduler, KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + USE_PEFT_BACKEND, + BaseOutput, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class DiffEditInversionPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + latents (`torch.Tensor`) + inverted latents tensor + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `num_timesteps * batch_size` or numpy array of shape `(num_timesteps, + batch_size, height, width, num_channels)`. PIL images or numpy array present the denoised images of the + diffusion pipeline. + """ + + latents: torch.Tensor + images: list[PIL.Image.Image] | np.ndarray + + +EXAMPLE_DOC_STRING = """ + + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionDiffEditPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" + + >>> init_image = download_image(img_url).resize((768, 768)) + + >>> pipeline = StableDiffusionDiffEditPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 + ... ) + + >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.enable_model_cpu_offload() + + >>> mask_prompt = "A bowl of fruits" + >>> prompt = "A bowl of pears" + + >>> mask_image = pipeline.generate_mask(image=init_image, source_prompt=prompt, target_prompt=mask_prompt) + >>> image_latents = pipeline.invert(image=init_image, prompt=mask_prompt).latents + >>> image = pipeline(prompt=prompt, mask_image=mask_image, image_latents=image_latents).images[0] + ``` +""" + +EXAMPLE_INVERT_DOC_STRING = """ + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionDiffEditPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" + + >>> init_image = download_image(img_url).resize((768, 768)) + + >>> pipeline = StableDiffusionDiffEditPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 + ... ) + + >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.enable_model_cpu_offload() + + >>> prompt = "A bowl of fruits" + + >>> inverted_latents = pipeline.invert(image=init_image, prompt=prompt).latents + ``` +""" + + +def auto_corr_loss(hidden_states, generator=None): + reg_loss = 0.0 + for i in range(hidden_states.shape[0]): + for j in range(hidden_states.shape[1]): + noise = hidden_states[i : i + 1, j : j + 1, :, :] + while True: + roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item() + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2 + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2 + + if noise.shape[2] <= 8: + break + noise = torch.nn.functional.avg_pool2d(noise, kernel_size=2) + return reg_loss + + +def kl_divergence(hidden_states): + return hidden_states.var() + hidden_states.mean() ** 2 - 1 - torch.log(hidden_states.var() + 1e-7) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +def preprocess_mask(mask, batch_size: int = 1): + if not isinstance(mask, torch.Tensor): + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list): + if isinstance(mask[0], PIL.Image.Image): + mask = [np.array(m.convert("L")).astype(np.float32) / 255.0 for m in mask] + if isinstance(mask[0], np.ndarray): + mask = np.stack(mask, axis=0) if mask[0].ndim < 3 else np.concatenate(mask, axis=0) + mask = torch.from_numpy(mask) + elif isinstance(mask[0], torch.Tensor): + mask = torch.stack(mask, dim=0) if mask[0].ndim < 3 else torch.cat(mask, dim=0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + # Check mask shape + if batch_size > 1: + if mask.shape[0] == 1: + mask = torch.cat([mask] * batch_size) + elif mask.shape[0] > 1 and mask.shape[0] != batch_size: + raise ValueError( + f"`mask_image` with batch size {mask.shape[0]} cannot be broadcasted to batch size {batch_size} " + f"inferred by prompt inputs" + ) + + if mask.shape[1] != 1: + raise ValueError(f"`mask_image` must have 1 channel, but has {mask.shape[1]} channels") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("`mask_image` should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + return mask + + +class StableDiffusionDiffEditPipeline( + DeprecatedPipelineMixin, + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, +): + r""" + > [!WARNING] > This is an experimental feature! + + Pipeline for text-guided image inpainting using Stable Diffusion and DiffEdit. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading and saving methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + inverse_scheduler ([`DDIMInverseScheduler`]): + A scheduler to be used in combination with `unet` to fill in the unmasked part of the input latents. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + _last_supported_version = "0.33.1" + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "inverse_scheduler"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + inverse_scheduler: DDIMInverseScheduler, + requires_safety_checker: bool = True, + ): + super().__init__() + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration" + " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" + " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" + " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face" + " Hub, it would be very nice if you could open a Pull request for the" + " `scheduler/scheduler_config.json` file" + ) + deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["skip_prk_steps"] = True + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + inverse_scheduler=inverse_scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def check_inputs( + self, + prompt, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (strength is None) or (strength is not None and (strength < 0 or strength > 1)): + raise ValueError( + f"The value of `strength` should in [0.0, 1.0] but is, but is {strength} of type {type(strength)}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def check_source_inputs( + self, + source_prompt=None, + source_negative_prompt=None, + source_prompt_embeds=None, + source_negative_prompt_embeds=None, + ): + if source_prompt is not None and source_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `source_prompt`: {source_prompt} and `source_prompt_embeds`: {source_prompt_embeds}." + " Please make sure to only forward one of the two." + ) + elif source_prompt is None and source_prompt_embeds is None: + raise ValueError( + "Provide either `source_image` or `source_prompt_embeds`. Cannot leave all both of the arguments undefined." + ) + elif source_prompt is not None and ( + not isinstance(source_prompt, str) and not isinstance(source_prompt, list) + ): + raise ValueError(f"`source_prompt` has to be of type `str` or `list` but is {type(source_prompt)}") + + if source_negative_prompt is not None and source_negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `source_negative_prompt`: {source_negative_prompt} and `source_negative_prompt_embeds`:" + f" {source_negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if source_prompt_embeds is not None and source_negative_prompt_embeds is not None: + if source_prompt_embeds.shape != source_negative_prompt_embeds.shape: + raise ValueError( + "`source_prompt_embeds` and `source_negative_prompt_embeds` must have the same shape when passed" + f" directly, but got: `source_prompt_embeds` {source_prompt_embeds.shape} !=" + f" `source_negative_prompt_embeds` {source_negative_prompt_embeds.shape}." + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def get_inverse_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + + # safety for t_start overflow to prevent empty timsteps slice + if t_start == 0: + return self.inverse_scheduler.timesteps, num_inference_steps + timesteps = self.inverse_scheduler.timesteps[:-t_start] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_image_latents(self, image, batch_size, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 4: + latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0) + else: + latents = self.vae.encode(image).latent_dist.sample(generator) + + latents = self.vae.config.scaling_factor * latents + + if batch_size != latents.shape[0]: + if batch_size % latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_latents_per_image = batch_size // latents.shape[0] + latents = torch.cat([latents] * additional_latents_per_image, dim=0) + else: + raise ValueError( + f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts." + ) + else: + latents = torch.cat([latents], dim=0) + + return latents + + def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int): + pred_type = self.inverse_scheduler.config.prediction_type + alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep] + + beta_prod_t = 1 - alpha_prod_t + + if pred_type == "epsilon": + return model_output + elif pred_type == "sample": + return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5) + elif pred_type == "v_prediction": + return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`" + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def generate_mask( + self, + image: torch.Tensor | PIL.Image.Image = None, + target_prompt: str | list[str] | None = None, + target_negative_prompt: str | list[str] | None = None, + target_prompt_embeds: torch.Tensor | None = None, + target_negative_prompt_embeds: torch.Tensor | None = None, + source_prompt: str | list[str] | None = None, + source_negative_prompt: str | list[str] | None = None, + source_prompt_embeds: torch.Tensor | None = None, + source_negative_prompt_embeds: torch.Tensor | None = None, + num_maps_per_mask: int | None = 10, + mask_encode_strength: float | None = 0.5, + mask_thresholding_ratio: float | None = 3.0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + generator: torch.Generator | list[torch.Generator] | None = None, + output_type: str | None = "np", + cross_attention_kwargs: dict[str, Any] | None = None, + ): + r""" + Generate a latent mask given a mask prompt, a target prompt, and an image. + + Args: + image (`PIL.Image.Image`): + `Image` or tensor representing an image batch to be used for computing the mask. + target_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide semantic mask generation. If not defined, you need to pass + `prompt_embeds`. + target_negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + target_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + target_negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + source_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide semantic mask generation using DiffEdit. If not defined, you need to + pass `source_prompt_embeds` or `source_image` instead. + source_negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide semantic mask generation away from using DiffEdit. If not defined, you + need to pass `source_negative_prompt_embeds` or `source_image` instead. + source_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings to guide the semantic mask generation. Can be used to easily tweak text + inputs (prompt weighting). If not provided, text embeddings are generated from `source_prompt` input + argument. + source_negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings to negatively guide the semantic mask generation. Can be used to easily + tweak text inputs (prompt weighting). If not provided, text embeddings are generated from + `source_negative_prompt` input argument. + num_maps_per_mask (`int`, *optional*, defaults to 10): + The number of noise maps sampled to generate the semantic mask using DiffEdit. + mask_encode_strength (`float`, *optional*, defaults to 0.5): + The strength of the noise maps sampled to generate the semantic mask using DiffEdit. Must be between 0 + and 1. + mask_thresholding_ratio (`float`, *optional*, defaults to 3.0): + The maximum multiple of the mean absolute difference used to clamp the semantic guidance map before + mask binarization. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the + [`~models.attention_processor.AttnProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + Returns: + `list[PIL.Image.Image]` or `np.array`: + When returning a `list[PIL.Image.Image]`, the list consists of a batch of single-channel binary images + with dimensions `(height // self.vae_scale_factor, width // self.vae_scale_factor)`. If it's + `np.array`, the shape is `(batch_size, height // self.vae_scale_factor, width // + self.vae_scale_factor)`. + """ + + # 1. Check inputs (Provide dummy argument for callback_steps) + self.check_inputs( + target_prompt, + mask_encode_strength, + 1, + target_negative_prompt, + target_prompt_embeds, + target_negative_prompt_embeds, + ) + + self.check_source_inputs( + source_prompt, + source_negative_prompt, + source_prompt_embeds, + source_negative_prompt_embeds, + ) + + if (num_maps_per_mask is None) or ( + num_maps_per_mask is not None and (not isinstance(num_maps_per_mask, int) or num_maps_per_mask <= 0) + ): + raise ValueError( + f"`num_maps_per_mask` has to be a positive integer but is {num_maps_per_mask} of type" + f" {type(num_maps_per_mask)}." + ) + + if mask_thresholding_ratio is None or mask_thresholding_ratio <= 0: + raise ValueError( + f"`mask_thresholding_ratio` has to be positive but is {mask_thresholding_ratio} of type" + f" {type(mask_thresholding_ratio)}." + ) + + # 2. Define call parameters + if target_prompt is not None and isinstance(target_prompt, str): + batch_size = 1 + elif target_prompt is not None and isinstance(target_prompt, list): + batch_size = len(target_prompt) + else: + batch_size = target_prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompts + (cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None) + target_negative_prompt_embeds, target_prompt_embeds = self.encode_prompt( + target_prompt, + device, + num_maps_per_mask, + do_classifier_free_guidance, + target_negative_prompt, + prompt_embeds=target_prompt_embeds, + negative_prompt_embeds=target_negative_prompt_embeds, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + target_prompt_embeds = torch.cat([target_negative_prompt_embeds, target_prompt_embeds]) + + source_negative_prompt_embeds, source_prompt_embeds = self.encode_prompt( + source_prompt, + device, + num_maps_per_mask, + do_classifier_free_guidance, + source_negative_prompt, + prompt_embeds=source_prompt_embeds, + negative_prompt_embeds=source_negative_prompt_embeds, + ) + if do_classifier_free_guidance: + source_prompt_embeds = torch.cat([source_negative_prompt_embeds, source_prompt_embeds]) + + # 4. Preprocess image + image = self.image_processor.preprocess(image).repeat_interleave(num_maps_per_mask, dim=0) + + # 5. Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, _ = self.get_timesteps(num_inference_steps, mask_encode_strength, device) + encode_timestep = timesteps[0] + + # 6. Prepare image latents and add noise with specified strength + image_latents = self.prepare_image_latents( + image, batch_size * num_maps_per_mask, self.vae.dtype, device, generator + ) + noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=self.vae.dtype) + image_latents = self.scheduler.add_noise(image_latents, noise, encode_timestep) + + latent_model_input = torch.cat([image_latents] * (4 if do_classifier_free_guidance else 2)) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, encode_timestep) + + # 7. Predict the noise residual + prompt_embeds = torch.cat([source_prompt_embeds, target_prompt_embeds]) + noise_pred = self.unet( + latent_model_input, + encode_timestep, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + if do_classifier_free_guidance: + noise_pred_neg_src, noise_pred_source, noise_pred_uncond, noise_pred_target = noise_pred.chunk(4) + noise_pred_source = noise_pred_neg_src + guidance_scale * (noise_pred_source - noise_pred_neg_src) + noise_pred_target = noise_pred_uncond + guidance_scale * (noise_pred_target - noise_pred_uncond) + else: + noise_pred_source, noise_pred_target = noise_pred.chunk(2) + + # 8. Compute the mask from the absolute difference of predicted noise residuals + # TODO: Consider smoothing mask guidance map + mask_guidance_map = ( + torch.abs(noise_pred_target - noise_pred_source) + .reshape(batch_size, num_maps_per_mask, *noise_pred_target.shape[-3:]) + .mean([1, 2]) + ) + clamp_magnitude = mask_guidance_map.mean() * mask_thresholding_ratio + semantic_mask_image = mask_guidance_map.clamp(0, clamp_magnitude) / clamp_magnitude + semantic_mask_image = torch.where(semantic_mask_image <= 0.5, 0, 1) + mask_image = semantic_mask_image.cpu().numpy() + + # 9. Convert to Numpy array or PIL. + if output_type == "pil": + mask_image = self.image_processor.numpy_to_pil(mask_image) + + # Offload all models + self.maybe_free_model_hooks() + + return mask_image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_INVERT_DOC_STRING) + def invert( + self, + prompt: str | list[str] | None = None, + image: torch.Tensor | PIL.Image.Image = None, + num_inference_steps: int = 50, + inpaint_strength: float = 0.8, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + decode_latents: bool = False, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int | None = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + lambda_auto_corr: float = 20.0, + lambda_kl: float = 20.0, + num_reg_steps: int = 0, + num_auto_corr_rolls: int = 5, + ): + r""" + Generate inverted latents given a prompt and image. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`PIL.Image.Image`): + `Image` or tensor representing an image batch to produce the inverted latents guided by `prompt`. + inpaint_strength (`float`, *optional*, defaults to 0.8): + Indicates extent of the noising process to run latent inversion. Must be between 0 and 1. When + `inpaint_strength` is 1, the inversion process is run for the full number of iterations specified in + `num_inference_steps`. `image` is used as a reference for the inversion process, and adding more noise + increases `inpaint_strength`. If `inpaint_strength` is 0, no inpainting occurs. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + decode_latents (`bool`, *optional*, defaults to `False`): + Whether or not to decode the inverted latents into a generated image. Setting this argument to `True` + decodes all inverted latents for each timestep into a list of generated images. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.DiffEditInversionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the + [`~models.attention_processor.AttnProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + lambda_auto_corr (`float`, *optional*, defaults to 20.0): + Lambda parameter to control auto correction. + lambda_kl (`float`, *optional*, defaults to 20.0): + Lambda parameter to control Kullback-Leibler divergence output. + num_reg_steps (`int`, *optional*, defaults to 0): + Number of regularization loss steps. + num_auto_corr_rolls (`int`, *optional*, defaults to 5): + Number of auto correction roll steps. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] or + `tuple`: + If `return_dict` is `True`, + [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is the inverted latents tensors + ordered by increasing noise, and the second is the corresponding decoded images if `decode_latents` is + `True`, otherwise `None`. + """ + + # 1. Check inputs + self.check_inputs( + prompt, + inpaint_strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Preprocess image + image = self.image_processor.preprocess(image) + + # 4. Prepare latent variables + num_images_per_prompt = 1 + latents = self.prepare_image_latents( + image, batch_size * num_images_per_prompt, self.vae.dtype, device, generator + ) + + # 5. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 6. Prepare timesteps + self.inverse_scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_inverse_timesteps(num_inference_steps, inpaint_strength, device) + + # 7. Noising loop where we obtain the intermediate noised latent image for each timestep. + num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order + inverted_latents = [] + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # regularization of the noise prediction (not in original code or paper but borrowed from Pix2PixZero) + if num_reg_steps > 0: + with torch.enable_grad(): + for _ in range(num_reg_steps): + if lambda_auto_corr > 0: + for _ in range(num_auto_corr_rolls): + var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) + + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + + l_ac = auto_corr_loss(var_epsilon, generator=generator) + l_ac.backward() + + grad = var.grad.detach() / num_auto_corr_rolls + noise_pred = noise_pred - lambda_auto_corr * grad + + if lambda_kl > 0: + var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) + + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + + l_kld = kl_divergence(var_epsilon) + l_kld.backward() + + grad = var.grad.detach() + noise_pred = noise_pred - lambda_kl * grad + + noise_pred = noise_pred.detach() + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample + inverted_latents.append(latents.detach().clone()) + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + assert len(inverted_latents) == len(timesteps) + latents = torch.stack(list(reversed(inverted_latents)), 1) + + # 8. Post-processing + image = None + if decode_latents: + image = self.decode_latents(latents.flatten(0, 1)) + + # 9. Convert to PIL. + if decode_latents and output_type == "pil": + image = self.image_processor.numpy_to_pil(image) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (latents, image) + + return DiffEditInversionPipelineOutput(latents=latents, images=image) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + mask_image: torch.Tensor | PIL.Image.Image = None, + image_latents: torch.Tensor | PIL.Image.Image = None, + inpaint_strength: float | None = 0.8, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + mask_image (`PIL.Image.Image`): + `Image` or tensor representing an image batch to mask the generated image. White pixels in the mask are + repainted, while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, 1, H, W)`. + image_latents (`PIL.Image.Image` or `torch.Tensor`): + Partially noised image latents from the inversion process to be used as inputs for image generation. + inpaint_strength (`float`, *optional*, defaults to 0.8): + Indicates extent to inpaint the masked area. Must be between 0 and 1. When `inpaint_strength` is 1, the + denoising process is run on the masked area for the full number of iterations specified in + `num_inference_steps`. `image_latents` is used as a reference for the masked area, and adding more + noise to a region increases `inpaint_strength`. If `inpaint_strength` is 0, no inpainting occurs. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + # 1. Check inputs + self.check_inputs( + prompt, + inpaint_strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + if mask_image is None: + raise ValueError( + "`mask_image` input cannot be undefined. Use `generate_mask()` to compute `mask_image` from text prompts." + ) + if image_latents is None: + raise ValueError( + "`image_latents` input cannot be undefined. Use `invert()` to compute `image_latents` from input images." + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Preprocess mask + mask_image = preprocess_mask(mask_image, batch_size) + latent_height, latent_width = mask_image.shape[-2:] + mask_image = torch.cat([mask_image] * num_images_per_prompt) + mask_image = mask_image.to(device=device, dtype=prompt_embeds.dtype) + + # 5. Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, inpaint_strength, device) + + # 6. Preprocess image latents + if isinstance(image_latents, list) and any(isinstance(l, torch.Tensor) and l.ndim == 5 for l in image_latents): + image_latents = torch.cat(image_latents).detach() + elif isinstance(image_latents, torch.Tensor) and image_latents.ndim == 5: + image_latents = image_latents.detach() + else: + image_latents = self.image_processor.preprocess(image_latents).detach() + + latent_shape = (self.vae.config.latent_channels, latent_height, latent_width) + if image_latents.shape[-3:] != latent_shape: + raise ValueError( + f"Each latent image in `image_latents` must have shape {latent_shape}, " + f"but has shape {image_latents.shape[-3:]}" + ) + if image_latents.ndim == 4: + image_latents = image_latents.reshape(batch_size, len(timesteps), *latent_shape) + if image_latents.shape[:2] != (batch_size, len(timesteps)): + raise ValueError( + f"`image_latents` must have batch size {batch_size} with latent images from {len(timesteps)}" + f" timesteps, but has batch size {image_latents.shape[0]} with latent images from" + f" {image_latents.shape[1]} timesteps." + ) + image_latents = image_latents.transpose(0, 1).repeat_interleave(num_images_per_prompt, dim=1) + image_latents = image_latents.to(device=device, dtype=prompt_embeds.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + latents = image_latents[0].clone() + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # mask with inverted latents from appropriate timestep - use original image latent for last step + latents = latents * mask_image + image_latents[i] * (1 - mask_image) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_gligen/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_gligen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..147980cbf9e5c3418fc1854787ae37b25e4fed56 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_gligen/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_gligen"] = ["StableDiffusionGLIGENPipeline"] + _import_structure["pipeline_stable_diffusion_gligen_text_image"] = ["StableDiffusionGLIGENTextImagePipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_diffusion_gligen import StableDiffusionGLIGENPipeline + from .pipeline_stable_diffusion_gligen_text_image import StableDiffusionGLIGENTextImagePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_gligen/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_gligen/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6144f831566dd5518c00e5681c84a7c08a06609b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_gligen/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_gligen/__pycache__/pipeline_stable_diffusion_gligen.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_gligen/__pycache__/pipeline_stable_diffusion_gligen.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b11f9fd662b84fcf12cc451563f4d33bb21cc580 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_gligen/__pycache__/pipeline_stable_diffusion_gligen.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_gligen/__pycache__/pipeline_stable_diffusion_gligen_text_image.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_gligen/__pycache__/pipeline_stable_diffusion_gligen_text_image.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f42f8d15433c7c34442bf01a943f84f41fae62c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_gligen/__pycache__/pipeline_stable_diffusion_gligen_text_image.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py new file mode 100644 index 0000000000000000000000000000000000000000..fa5bc9376e53ef5a262d92af4b730b9b8705de27 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py @@ -0,0 +1,865 @@ +# Copyright 2025 The GLIGEN Authors and HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable + +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention import GatedSelfAttentionDense +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionGLIGENPipeline + >>> from diffusers.utils import load_image + + >>> # Insert objects described by text at the region defined by bounding boxes + >>> pipe = StableDiffusionGLIGENPipeline.from_pretrained( + ... "masterful/gligen-1-4-inpainting-text-box", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> input_image = load_image( + ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/gligen/livingroom_modern.png" + ... ) + >>> prompt = "a birthday cake" + >>> boxes = [[0.2676, 0.6088, 0.4773, 0.7183]] + >>> phrases = ["a birthday cake"] + + >>> images = pipe( + ... prompt=prompt, + ... gligen_phrases=phrases, + ... gligen_inpaint_image=input_image, + ... gligen_boxes=boxes, + ... gligen_scheduled_sampling_beta=1, + ... output_type="pil", + ... num_inference_steps=50, + ... ).images + + >>> images[0].save("./gligen-1-4-inpainting-text-box.jpg") + + >>> # Generate an image described by the prompt and + >>> # insert objects described by text at the region defined by bounding boxes + >>> pipe = StableDiffusionGLIGENPipeline.from_pretrained( + ... "masterful/gligen-1-4-generation-text-box", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a waterfall and a modern high speed train running through the tunnel in a beautiful forest with fall foliage" + >>> boxes = [[0.1387, 0.2051, 0.4277, 0.7090], [0.4980, 0.4355, 0.8516, 0.7266]] + >>> phrases = ["a waterfall", "a modern high speed train running through the tunnel"] + + >>> images = pipe( + ... prompt=prompt, + ... gligen_phrases=phrases, + ... gligen_boxes=boxes, + ... gligen_scheduled_sampling_beta=1, + ... output_type="pil", + ... num_inference_steps=50, + ... ).images + + >>> images[0].save("./gligen-1-4-generation-text-box.jpg") + ``` +""" + + +class StableDiffusionGLIGENPipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion with Grounded-Language-to-Image Generation (GLIGEN). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + _last_supported_version = "0.33.1" + + _optional_components = ["safety_checker", "feature_extractor"] + model_cpu_offload_seq = "text_encoder->unet->vae" + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + gligen_phrases, + gligen_boxes, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if len(gligen_phrases) != len(gligen_boxes): + raise ValueError( + "length of `gligen_phrases` and `gligen_boxes` has to be same, but" + f" got: `gligen_phrases` {len(gligen_phrases)} != `gligen_boxes` {len(gligen_boxes)}" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def enable_fuser(self, enabled=True): + for module in self.unet.modules(): + if type(module) is GatedSelfAttentionDense: + module.enabled = enabled + + def draw_inpaint_mask_from_boxes(self, boxes, size): + inpaint_mask = torch.ones(size[0], size[1]) + for box in boxes: + x0, x1 = box[0] * size[0], box[2] * size[0] + y0, y1 = box[1] * size[1], box[3] * size[1] + inpaint_mask[int(y0) : int(y1), int(x0) : int(x1)] = 0 + return inpaint_mask + + def crop(self, im, new_width, new_height): + width, height = im.size + left = (width - new_width) / 2 + top = (height - new_height) / 2 + right = (width + new_width) / 2 + bottom = (height + new_height) / 2 + return im.crop((left, top, right, bottom)) + + def target_size_center_crop(self, im, new_hw): + width, height = im.size + if width != height: + im = self.crop(im, min(height, width), min(height, width)) + return im.resize((new_hw, new_hw), PIL.Image.LANCZOS) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + gligen_scheduled_sampling_beta: float = 0.3, + gligen_phrases: list[str] = None, + gligen_boxes: list[list[float]] = None, + gligen_inpaint_image: PIL.Image.Image | None = None, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + gligen_phrases (`list[str]`): + The phrases to guide what to include in each of the regions defined by the corresponding + `gligen_boxes`. There should only be one phrase per bounding box. + gligen_boxes (`list[list[float]]`): + The bounding boxes that identify rectangular regions of the image that are going to be filled with the + content described by the corresponding `gligen_phrases`. Each rectangular box is defined as a + `list[float]` of 4 elements `[xmin, ymin, xmax, ymax]` where each value is between [0,1]. + gligen_inpaint_image (`PIL.Image.Image`, *optional*): + The input image, if provided, is inpainted with objects described by the `gligen_boxes` and + `gligen_phrases`. Otherwise, it is treated as a generation task on a blank input image. + gligen_scheduled_sampling_beta (`float`, defaults to 0.3): + Scheduled Sampling factor from [GLIGEN: Open-Set Grounded Text-to-Image + Generation](https://huggingface.co/papers/2301.07093). Scheduled Sampling factor is only varied for + scheduled sampling during inference for improved quality and controllability. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + gligen_phrases, + gligen_boxes, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5.1 Prepare GLIGEN variables + max_objs = 30 + if len(gligen_boxes) > max_objs: + warnings.warn( + f"More that {max_objs} objects found. Only first {max_objs} objects will be processed.", + FutureWarning, + ) + gligen_phrases = gligen_phrases[:max_objs] + gligen_boxes = gligen_boxes[:max_objs] + # prepare batched input to the GLIGENTextBoundingboxProjection (boxes, phrases, mask) + # Get tokens for phrases from pre-trained CLIPTokenizer + tokenizer_inputs = self.tokenizer(gligen_phrases, padding=True, return_tensors="pt").to(device) + # For the token, we use the same pre-trained text encoder + # to obtain its text feature + _text_embeddings = self.text_encoder(**tokenizer_inputs).pooler_output + n_objs = len(gligen_boxes) + # For each entity, described in phrases, is denoted with a bounding box, + # we represent the location information as (xmin,ymin,xmax,ymax) + boxes = torch.zeros(max_objs, 4, device=device, dtype=self.text_encoder.dtype) + boxes[:n_objs] = torch.tensor(gligen_boxes) + text_embeddings = torch.zeros( + max_objs, self.unet.config.cross_attention_dim, device=device, dtype=self.text_encoder.dtype + ) + text_embeddings[:n_objs] = _text_embeddings + # Generate a mask for each object that is entity described by phrases + masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype) + masks[:n_objs] = 1 + + repeat_batch = batch_size * num_images_per_prompt + boxes = boxes.unsqueeze(0).expand(repeat_batch, -1, -1).clone() + text_embeddings = text_embeddings.unsqueeze(0).expand(repeat_batch, -1, -1).clone() + masks = masks.unsqueeze(0).expand(repeat_batch, -1).clone() + if do_classifier_free_guidance: + repeat_batch = repeat_batch * 2 + boxes = torch.cat([boxes] * 2) + text_embeddings = torch.cat([text_embeddings] * 2) + masks = torch.cat([masks] * 2) + masks[: repeat_batch // 2] = 0 + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + cross_attention_kwargs["gligen"] = {"boxes": boxes, "positive_embeddings": text_embeddings, "masks": masks} + + # Prepare latent variables for GLIGEN inpainting + if gligen_inpaint_image is not None: + # if the given input image is not of the same size as expected by VAE + # center crop and resize the input image to expected shape + if gligen_inpaint_image.size != (self.vae.sample_size, self.vae.sample_size): + gligen_inpaint_image = self.target_size_center_crop(gligen_inpaint_image, self.vae.sample_size) + # Convert a single image into a batch of images with a batch size of 1 + # The resulting shape becomes (1, C, H, W), where C is the number of channels, + # and H and W are the height and width of the image. + # scales the pixel values to a range [-1, 1] + gligen_inpaint_image = self.image_processor.preprocess(gligen_inpaint_image) + gligen_inpaint_image = gligen_inpaint_image.to(dtype=self.vae.dtype, device=self.vae.device) + # Run AutoEncoder to get corresponding latents + gligen_inpaint_latent = self.vae.encode(gligen_inpaint_image).latent_dist.sample() + gligen_inpaint_latent = self.vae.config.scaling_factor * gligen_inpaint_latent + # Generate an inpainting mask + # pixel value = 0, where the object is present (defined by bounding boxes above) + # 1, everywhere else + gligen_inpaint_mask = self.draw_inpaint_mask_from_boxes(gligen_boxes, gligen_inpaint_latent.shape[2:]) + gligen_inpaint_mask = gligen_inpaint_mask.to( + dtype=gligen_inpaint_latent.dtype, device=gligen_inpaint_latent.device + ) + gligen_inpaint_mask = gligen_inpaint_mask[None, None] + gligen_inpaint_mask_addition = torch.cat( + (gligen_inpaint_latent * gligen_inpaint_mask, gligen_inpaint_mask), dim=1 + ) + # Convert a single mask into a batch of masks with a batch size of 1 + gligen_inpaint_mask_addition = gligen_inpaint_mask_addition.expand(repeat_batch, -1, -1, -1).clone() + + num_grounding_steps = int(gligen_scheduled_sampling_beta * len(timesteps)) + self.enable_fuser(True) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Scheduled sampling + if i == num_grounding_steps: + self.enable_fuser(False) + + if latents.shape[1] != 4: + latents = torch.randn_like(latents[:, :4]) + + if gligen_inpaint_image is not None: + gligen_inpaint_latent_with_noise = ( + self.scheduler.add_noise( + gligen_inpaint_latent, torch.randn_like(gligen_inpaint_latent), torch.tensor([t]) + ) + .expand(latents.shape[0], -1, -1, -1) + .clone() + ) + latents = gligen_inpaint_latent_with_noise * gligen_inpaint_mask + latents * ( + 1 - gligen_inpaint_mask + ) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if gligen_inpaint_image is not None: + latent_model_input = torch.cat((latent_model_input, gligen_inpaint_mask_addition), dim=1) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py new file mode 100644 index 0000000000000000000000000000000000000000..62e8a9fa95aeb1c9ef923da2dd2b8a878259024b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py @@ -0,0 +1,1053 @@ +# Copyright 2025 The GLIGEN Authors and HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from typing import Any, Callable + +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...image_processor import VaeImageProcessor +from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention import GatedSelfAttentionDense +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.clip_image_project_model import CLIPImageProjection +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionGLIGENTextImagePipeline + >>> from diffusers.utils import load_image + + >>> # Insert objects described by image at the region defined by bounding boxes + >>> pipe = StableDiffusionGLIGENTextImagePipeline.from_pretrained( + ... "anhnct/Gligen_Inpainting_Text_Image", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> input_image = load_image( + ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/gligen/livingroom_modern.png" + ... ) + >>> prompt = "a backpack" + >>> boxes = [[0.2676, 0.4088, 0.4773, 0.7183]] + >>> phrases = None + >>> gligen_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/gligen/backpack.jpeg" + ... ) + + >>> images = pipe( + ... prompt=prompt, + ... gligen_phrases=phrases, + ... gligen_inpaint_image=input_image, + ... gligen_boxes=boxes, + ... gligen_images=[gligen_image], + ... gligen_scheduled_sampling_beta=1, + ... output_type="pil", + ... num_inference_steps=50, + ... ).images + + >>> images[0].save("./gligen-inpainting-text-image-box.jpg") + + >>> # Generate an image described by the prompt and + >>> # insert objects described by text and image at the region defined by bounding boxes + >>> pipe = StableDiffusionGLIGENTextImagePipeline.from_pretrained( + ... "anhnct/Gligen_Text_Image", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a flower sitting on the beach" + >>> boxes = [[0.0, 0.09, 0.53, 0.76]] + >>> phrases = ["flower"] + >>> gligen_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/gligen/pexels-pixabay-60597.jpg" + ... ) + + >>> images = pipe( + ... prompt=prompt, + ... gligen_phrases=phrases, + ... gligen_images=[gligen_image], + ... gligen_boxes=boxes, + ... gligen_scheduled_sampling_beta=1, + ... output_type="pil", + ... num_inference_steps=50, + ... ).images + + >>> images[0].save("./gligen-generation-text-image-box.jpg") + + >>> # Generate an image described by the prompt and + >>> # transfer style described by image at the region defined by bounding boxes + >>> pipe = StableDiffusionGLIGENTextImagePipeline.from_pretrained( + ... "anhnct/Gligen_Text_Image", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a dragon flying on the sky" + >>> boxes = [[0.4, 0.2, 1.0, 0.8], [0.0, 1.0, 0.0, 1.0]] # Set `[0.0, 1.0, 0.0, 1.0]` for the style + + >>> gligen_image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/landscape.png" + ... ) + + >>> gligen_placeholder = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/landscape.png" + ... ) + + >>> images = pipe( + ... prompt=prompt, + ... gligen_phrases=[ + ... "dragon", + ... "placeholder", + ... ], # Can use any text instead of `placeholder` token, because we will use mask here + ... gligen_images=[ + ... gligen_placeholder, + ... gligen_image, + ... ], # Can use any image in gligen_placeholder, because we will use mask here + ... input_phrases_mask=[1, 0], # Set 0 for the placeholder token + ... input_images_mask=[0, 1], # Set 0 for the placeholder image + ... gligen_boxes=boxes, + ... gligen_scheduled_sampling_beta=1, + ... output_type="pil", + ... num_inference_steps=50, + ... ).images + + >>> images[0].save("./gligen-generation-text-image-box-style-transfer.jpg") + ``` +""" + + +class StableDiffusionGLIGENTextImagePipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion with Grounded-Language-to-Image Generation (GLIGEN). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + processor ([`~transformers.CLIPProcessor`]): + A `CLIPProcessor` to process reference image. + image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): + Frozen image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + image_project ([`CLIPImageProjection`]): + A `CLIPImageProjection` to project image embedding into phrases embedding space. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + _last_supported_version = "0.33.1" + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + processor: CLIPProcessor, + image_encoder: CLIPVisionModelWithProjection, + image_project: CLIPImageProjection, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + image_encoder=image_encoder, + processor=processor, + image_project=image_project, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + gligen_images, + gligen_phrases, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if gligen_images is not None and gligen_phrases is not None: + if len(gligen_images) != len(gligen_phrases): + raise ValueError( + "`gligen_images` and `gligen_phrases` must have the same length when both are provided, but" + f" got: `gligen_images` with length {len(gligen_images)} != `gligen_phrases` with length {len(gligen_phrases)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def enable_fuser(self, enabled=True): + for module in self.unet.modules(): + if type(module) is GatedSelfAttentionDense: + module.enabled = enabled + + def draw_inpaint_mask_from_boxes(self, boxes, size): + """ + Create an inpainting mask based on given boxes. This function generates an inpainting mask using the provided + boxes to mark regions that need to be inpainted. + """ + inpaint_mask = torch.ones(size[0], size[1]) + for box in boxes: + x0, x1 = box[0] * size[0], box[2] * size[0] + y0, y1 = box[1] * size[1], box[3] * size[1] + inpaint_mask[int(y0) : int(y1), int(x0) : int(x1)] = 0 + return inpaint_mask + + def crop(self, im, new_width, new_height): + """ + Crop the input image to the specified dimensions. + """ + width, height = im.size + left = (width - new_width) / 2 + top = (height - new_height) / 2 + right = (width + new_width) / 2 + bottom = (height + new_height) / 2 + return im.crop((left, top, right, bottom)) + + def target_size_center_crop(self, im, new_hw): + """ + Crop and resize the image to the target size while keeping the center. + """ + width, height = im.size + if width != height: + im = self.crop(im, min(height, width), min(height, width)) + return im.resize((new_hw, new_hw), PIL.Image.LANCZOS) + + def complete_mask(self, has_mask, max_objs, device): + """ + Based on the input mask corresponding value `0 or 1` for each phrases and image, mask the features + corresponding to phrases and images. + """ + mask = torch.ones(1, max_objs).type(self.text_encoder.dtype).to(device) + if has_mask is None: + return mask + + if isinstance(has_mask, int): + return mask * has_mask + else: + for idx, value in enumerate(has_mask): + mask[0, idx] = value + return mask + + def get_clip_feature(self, input, normalize_constant, device, is_image=False): + """ + Get image and phrases embedding by using CLIP pretrain model. The image embedding is transformed into the + phrases embedding space through a projection. + """ + if is_image: + if input is None: + return None + inputs = self.processor(images=[input], return_tensors="pt").to(device) + inputs["pixel_values"] = inputs["pixel_values"].to(self.image_encoder.dtype) + + outputs = self.image_encoder(**inputs) + feature = outputs.image_embeds + feature = self.image_project(feature).squeeze(0) + feature = (feature / feature.norm()) * normalize_constant + feature = feature.unsqueeze(0) + else: + if input is None: + return None + inputs = self.tokenizer(input, return_tensors="pt", padding=True).to(device) + outputs = self.text_encoder(**inputs) + feature = outputs.pooler_output + return feature + + def get_cross_attention_kwargs_with_grounded( + self, + hidden_size, + gligen_phrases, + gligen_images, + gligen_boxes, + input_phrases_mask, + input_images_mask, + repeat_batch, + normalize_constant, + max_objs, + device, + ): + """ + Prepare the cross-attention kwargs containing information about the grounded input (boxes, mask, image + embedding, phrases embedding). + """ + phrases, images = gligen_phrases, gligen_images + images = [None] * len(phrases) if images is None else images + phrases = [None] * len(images) if phrases is None else phrases + + boxes = torch.zeros(max_objs, 4, device=device, dtype=self.text_encoder.dtype) + masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype) + phrases_masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype) + image_masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype) + phrases_embeddings = torch.zeros(max_objs, hidden_size, device=device, dtype=self.text_encoder.dtype) + image_embeddings = torch.zeros(max_objs, hidden_size, device=device, dtype=self.text_encoder.dtype) + + text_features = [] + image_features = [] + for phrase, image in zip(phrases, images): + text_features.append(self.get_clip_feature(phrase, normalize_constant, device, is_image=False)) + image_features.append(self.get_clip_feature(image, normalize_constant, device, is_image=True)) + + for idx, (box, text_feature, image_feature) in enumerate(zip(gligen_boxes, text_features, image_features)): + boxes[idx] = torch.tensor(box) + masks[idx] = 1 + if text_feature is not None: + phrases_embeddings[idx] = text_feature + phrases_masks[idx] = 1 + if image_feature is not None: + image_embeddings[idx] = image_feature + image_masks[idx] = 1 + + input_phrases_mask = self.complete_mask(input_phrases_mask, max_objs, device) + phrases_masks = phrases_masks.unsqueeze(0).repeat(repeat_batch, 1) * input_phrases_mask + input_images_mask = self.complete_mask(input_images_mask, max_objs, device) + image_masks = image_masks.unsqueeze(0).repeat(repeat_batch, 1) * input_images_mask + boxes = boxes.unsqueeze(0).repeat(repeat_batch, 1, 1) + masks = masks.unsqueeze(0).repeat(repeat_batch, 1) + phrases_embeddings = phrases_embeddings.unsqueeze(0).repeat(repeat_batch, 1, 1) + image_embeddings = image_embeddings.unsqueeze(0).repeat(repeat_batch, 1, 1) + + out = { + "boxes": boxes, + "masks": masks, + "phrases_masks": phrases_masks, + "image_masks": image_masks, + "phrases_embeddings": phrases_embeddings, + "image_embeddings": image_embeddings, + } + + return out + + def get_cross_attention_kwargs_without_grounded(self, hidden_size, repeat_batch, max_objs, device): + """ + Prepare the cross-attention kwargs without information about the grounded input (boxes, mask, image embedding, + phrases embedding) (All are zero tensor). + """ + boxes = torch.zeros(max_objs, 4, device=device, dtype=self.text_encoder.dtype) + masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype) + phrases_masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype) + image_masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype) + phrases_embeddings = torch.zeros(max_objs, hidden_size, device=device, dtype=self.text_encoder.dtype) + image_embeddings = torch.zeros(max_objs, hidden_size, device=device, dtype=self.text_encoder.dtype) + + out = { + "boxes": boxes.unsqueeze(0).repeat(repeat_batch, 1, 1), + "masks": masks.unsqueeze(0).repeat(repeat_batch, 1), + "phrases_masks": phrases_masks.unsqueeze(0).repeat(repeat_batch, 1), + "image_masks": image_masks.unsqueeze(0).repeat(repeat_batch, 1), + "phrases_embeddings": phrases_embeddings.unsqueeze(0).repeat(repeat_batch, 1, 1), + "image_embeddings": image_embeddings.unsqueeze(0).repeat(repeat_batch, 1, 1), + } + + return out + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + gligen_scheduled_sampling_beta: float = 0.3, + gligen_phrases: list[str] = None, + gligen_images: list[PIL.Image.Image] = None, + input_phrases_mask: int | list[int] = None, + input_images_mask: int | list[int] = None, + gligen_boxes: list[list[float]] = None, + gligen_inpaint_image: PIL.Image.Image | None = None, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + gligen_normalize_constant: float = 28.7, + clip_skip: int = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + gligen_phrases (`list[str]`): + The phrases to guide what to include in each of the regions defined by the corresponding + `gligen_boxes`. There should only be one phrase per bounding box. + gligen_images (`list[PIL.Image.Image]`): + The images to guide what to include in each of the regions defined by the corresponding `gligen_boxes`. + There should only be one image per bounding box + input_phrases_mask (`int` or `list[int]`): + pre phrases mask input defined by the correspongding `input_phrases_mask` + input_images_mask (`int` or `list[int]`): + pre images mask input defined by the correspongding `input_images_mask` + gligen_boxes (`list[list[float]]`): + The bounding boxes that identify rectangular regions of the image that are going to be filled with the + content described by the corresponding `gligen_phrases`. Each rectangular box is defined as a + `list[float]` of 4 elements `[xmin, ymin, xmax, ymax]` where each value is between [0,1]. + gligen_inpaint_image (`PIL.Image.Image`, *optional*): + The input image, if provided, is inpainted with objects described by the `gligen_boxes` and + `gligen_phrases`. Otherwise, it is treated as a generation task on a blank input image. + gligen_scheduled_sampling_beta (`float`, defaults to 0.3): + Scheduled Sampling factor from [GLIGEN: Open-Set Grounded Text-to-Image + Generation](https://huggingface.co/papers/2301.07093). Scheduled Sampling factor is only varied for + scheduled sampling during inference for improved quality and controllability. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + gligen_normalize_constant (`float`, *optional*, defaults to 28.7): + The normalize value of the image embedding. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + gligen_images, + gligen_phrases, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clip_skip=clip_skip, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5.1 Prepare GLIGEN variables + max_objs = 30 + if len(gligen_boxes) > max_objs: + warnings.warn( + f"More that {max_objs} objects found. Only first {max_objs} objects will be processed.", + FutureWarning, + ) + gligen_phrases = gligen_phrases[:max_objs] + gligen_boxes = gligen_boxes[:max_objs] + gligen_images = gligen_images[:max_objs] + + repeat_batch = batch_size * num_images_per_prompt + + if do_classifier_free_guidance: + repeat_batch = repeat_batch * 2 + + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + hidden_size = prompt_embeds.shape[2] + + cross_attention_kwargs["gligen"] = self.get_cross_attention_kwargs_with_grounded( + hidden_size=hidden_size, + gligen_phrases=gligen_phrases, + gligen_images=gligen_images, + gligen_boxes=gligen_boxes, + input_phrases_mask=input_phrases_mask, + input_images_mask=input_images_mask, + repeat_batch=repeat_batch, + normalize_constant=gligen_normalize_constant, + max_objs=max_objs, + device=device, + ) + + cross_attention_kwargs_without_grounded = {} + cross_attention_kwargs_without_grounded["gligen"] = self.get_cross_attention_kwargs_without_grounded( + hidden_size=hidden_size, repeat_batch=repeat_batch, max_objs=max_objs, device=device + ) + + # Prepare latent variables for GLIGEN inpainting + if gligen_inpaint_image is not None: + # if the given input image is not of the same size as expected by VAE + # center crop and resize the input image to expected shape + if gligen_inpaint_image.size != (self.vae.sample_size, self.vae.sample_size): + gligen_inpaint_image = self.target_size_center_crop(gligen_inpaint_image, self.vae.sample_size) + # Convert a single image into a batch of images with a batch size of 1 + # The resulting shape becomes (1, C, H, W), where C is the number of channels, + # and H and W are the height and width of the image. + # scales the pixel values to a range [-1, 1] + gligen_inpaint_image = self.image_processor.preprocess(gligen_inpaint_image) + gligen_inpaint_image = gligen_inpaint_image.to(dtype=self.vae.dtype, device=self.vae.device) + # Run AutoEncoder to get corresponding latents + gligen_inpaint_latent = self.vae.encode(gligen_inpaint_image).latent_dist.sample() + gligen_inpaint_latent = self.vae.config.scaling_factor * gligen_inpaint_latent + # Generate an inpainting mask + # pixel value = 0, where the object is present (defined by bounding boxes above) + # 1, everywhere else + gligen_inpaint_mask = self.draw_inpaint_mask_from_boxes(gligen_boxes, gligen_inpaint_latent.shape[2:]) + gligen_inpaint_mask = gligen_inpaint_mask.to( + dtype=gligen_inpaint_latent.dtype, device=gligen_inpaint_latent.device + ) + gligen_inpaint_mask = gligen_inpaint_mask[None, None] + gligen_inpaint_mask_addition = torch.cat( + (gligen_inpaint_latent * gligen_inpaint_mask, gligen_inpaint_mask), dim=1 + ) + # Convert a single mask into a batch of masks with a batch size of 1 + gligen_inpaint_mask_addition = gligen_inpaint_mask_addition.expand(repeat_batch, -1, -1, -1).clone() + + int(gligen_scheduled_sampling_beta * len(timesteps)) + self.enable_fuser(True) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if latents.shape[1] != 4: + latents = torch.randn_like(latents[:, :4]) + + if gligen_inpaint_image is not None: + gligen_inpaint_latent_with_noise = ( + self.scheduler.add_noise( + gligen_inpaint_latent, torch.randn_like(gligen_inpaint_latent), torch.tensor([t]) + ) + .expand(latents.shape[0], -1, -1, -1) + .clone() + ) + latents = gligen_inpaint_latent_with_noise * gligen_inpaint_mask + latents * ( + 1 - gligen_inpaint_mask + ) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if gligen_inpaint_image is not None: + latent_model_input = torch.cat((latent_model_input, gligen_inpaint_mask_addition), dim=1) + + # predict the noise residual with grounded information + noise_pred_with_grounding = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # predict the noise residual without grounded information + noise_pred_without_grounding = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs_without_grounded, + ).sample + + # perform guidance + if do_classifier_free_guidance: + # Using noise_pred_text from noise residual with grounded information and noise_pred_uncond from noise residual without grounded information + _, noise_pred_text = noise_pred_with_grounding.chunk(2) + noise_pred_uncond, _ = noise_pred_without_grounding.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred = noise_pred_with_grounding + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_ldm3d/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_ldm3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dae2affddd1fd5952f454ed9cee906277dcceb16 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_ldm3d/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_ldm3d"] = ["StableDiffusionLDM3DPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_ldm3d/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_ldm3d/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25deffa9d11ba792f7a529189295671e476e7c7c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_ldm3d/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_ldm3d/__pycache__/pipeline_stable_diffusion_ldm3d.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_ldm3d/__pycache__/pipeline_stable_diffusion_ldm3d.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39598d6e686a112c74a4667184170d7a725ddc49 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_ldm3d/__pycache__/pipeline_stable_diffusion_ldm3d.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py new file mode 100644 index 0000000000000000000000000000000000000000..6de144aa7e8beeec286c4fc7077b6f163c0592e0 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py @@ -0,0 +1,1040 @@ +# Copyright 2025 The Intel Labs Team Authors and the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput, VaeImageProcessorLDM3D +from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + BaseOutput, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> from diffusers import StableDiffusionLDM3DPipeline + + >>> pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d-4c") + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> output = pipe(prompt) + >>> rgb_image, depth_image = output.rgb, output.depth + >>> rgb_image[0].save("astronaut_ldm3d_rgb.jpg") + >>> depth_image[0].save("astronaut_ldm3d_depth.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class LDM3DPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + rgb (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + depth (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + nsfw_content_detected (`list[bool]`) + list indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or + `None` if safety checking could not be performed. + """ + + rgb: list[PIL.Image.Image] | np.ndarray + depth: list[PIL.Image.Image] | np.ndarray + nsfw_content_detected: list[bool] | None + + +class StableDiffusionLDM3DPipeline( + DeprecatedPipelineMixin, + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, +): + _last_supported_version = "0.33.1" + + r""" + Pipeline for text-to-image and 3D generation using LDM3D. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection | None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + rgb_feature_extractor_input = feature_extractor_input[0] + safety_checker_input = self.feature_extractor(rgb_feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 49, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + guidance_rescale: float = 0.0, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 6.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + rgb, depth = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return ((rgb, depth), has_nsfw_concept) + + return LDM3DPipelineOutput(rgb=rgb, depth=depth, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_panorama/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_panorama/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f7572db7236cd6bcfd7dd032abcb29fd5f67cf1c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_panorama/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_panorama"] = ["StableDiffusionPanoramaPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py new file mode 100644 index 0000000000000000000000000000000000000000..259fbd9334303d1cf305d8616d1027f37e473579 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py @@ -0,0 +1,1201 @@ +# Copyright 2025 MultiDiffusion Authors and The HuggingFace Team. All rights reserved." +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +from typing import Any, Callable + +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import DDIMScheduler +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionPanoramaPipeline, DDIMScheduler + + >>> model_ckpt = "stabilityai/stable-diffusion-2-base" + >>> scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler") + >>> pipe = StableDiffusionPanoramaPipeline.from_pretrained( + ... model_ckpt, scheduler=scheduler, torch_dtype=torch.float16 + ... ) + + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of the dolomites" + >>> image = pipe(prompt).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionPanoramaPipeline( + DeprecatedPipelineMixin, + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, +): + _last_supported_version = "0.33.1" + + r""" + Pipeline for text-to-image generation using MultiDiffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection | None = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def decode_latents_with_padding(self, latents: torch.Tensor, padding: int = 8) -> torch.Tensor: + """ + Decode the given latents with padding for circular inference. + + Args: + latents (torch.Tensor): The input latents to decode. + padding (int, optional): The number of latents to add on each side for padding. Defaults to 8. + + Returns: + torch.Tensor: The decoded image with padding removed. + + Notes: + - The padding is added to remove boundary artifacts and improve the output quality. + - This would slightly increase the memory usage. + - The padding pixels are then removed from the decoded image. + + """ + latents = 1 / self.vae.config.scaling_factor * latents + latents_left = latents[..., :padding] + latents_right = latents[..., -padding:] + latents = torch.cat((latents_right, latents, latents_left), axis=-1) + image = self.vae.decode(latents, return_dict=False)[0] + padding_pix = self.vae_scale_factor * padding + image = image[..., padding_pix:-padding_pix] + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + def get_views( + self, + panorama_height: int, + panorama_width: int, + window_size: int = 64, + stride: int = 8, + circular_padding: bool = False, + ) -> list[tuple[int, int, int, int]]: + """ + Generates a list of views based on the given parameters. Here, we define the mappings F_i (see Eq. 7 in the + MultiDiffusion paper https://huggingface.co/papers/2302.08113). If panorama's height/width < window_size, + num_blocks of height/width should return 1. + + Args: + panorama_height (int): The height of the panorama. + panorama_width (int): The width of the panorama. + window_size (int, optional): The size of the window. Defaults to 64. + stride (int, optional): The stride value. Defaults to 8. + circular_padding (bool, optional): Whether to apply circular padding. Defaults to False. + + Returns: + list[tuple[int, int, int, int]]: A list of tuples representing the views. Each tuple contains four integers + representing the start and end coordinates of the window in the panorama. + + """ + panorama_height /= 8 + panorama_width /= 8 + num_blocks_height = (panorama_height - window_size) // stride + 1 if panorama_height > window_size else 1 + if circular_padding: + num_blocks_width = panorama_width // stride if panorama_width > window_size else 1 + else: + num_blocks_width = (panorama_width - window_size) // stride + 1 if panorama_width > window_size else 1 + total_num_blocks = int(num_blocks_height * num_blocks_width) + views = [] + for i in range(total_num_blocks): + h_start = int((i // num_blocks_width) * stride) + h_end = h_start + window_size + w_start = int((i % num_blocks_width) * stride) + w_end = w_start + window_size + views.append((h_start, h_end, w_start, w_end)) + return views + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def clip_skip(self): + return self._clip_skip + + @property + def do_classifier_free_guidance(self): + return False + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = 512, + width: int | None = 2048, + num_inference_steps: int = 50, + timesteps: list[int] = None, + guidance_scale: float = 7.5, + view_batch_size: int = 1, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + guidance_rescale: float = 0.0, + circular_padding: bool = False, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs: Any, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 2048): + The width in pixels of the generated image. The width is kept high because the pipeline is supposed + generate panorama-like images. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + The timesteps at which to generate the images. If not specified, then the default timestep spacing + strategy of the scheduler is used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + view_batch_size (`int`, *optional*, defaults to 1): + The batch size to denoise split views. For some GPUs with high performance, higher view batch size can + speedup the generation and increase the VRAM usage. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescaling factor for the guidance embeddings. A value of 0.0 means no rescaling is applied. + circular_padding (`bool`, *optional*, defaults to `False`): + If set to `True`, circular padding is applied to ensure there are no stitching artifacts. Circular + padding allows the model to seamlessly generate a transition from the rightmost part of the image to + the leftmost part, maintaining consistency in a 360-degree sense. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list[str]`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Define panorama grid and initialize views for synthesis. + # prepare batch grid + views = self.get_views(height, width, circular_padding=circular_padding) + views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)] + views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views_batch) + count = torch.zeros_like(latents) + value = torch.zeros_like(latents) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 7.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 8. Denoising loop + # Each denoising step also includes refinement of the latents with respect to the + # views. + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + count.zero_() + value.zero_() + + # generate views + # Here, we iterate through different spatial crops of the latents and denoise them. These + # denoised (latent) crops are then averaged to produce the final latent + # for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the + # MultiDiffusion paper for more details: https://huggingface.co/papers/2302.08113 + # Batch views denoise + for j, batch_view in enumerate(views_batch): + vb_size = len(batch_view) + # get the latents corresponding to the current view coordinates + if circular_padding: + latents_for_view = [] + for h_start, h_end, w_start, w_end in batch_view: + if w_end > latents.shape[3]: + # Add circular horizontal padding + latent_view = torch.cat( + ( + latents[:, :, h_start:h_end, w_start:], + latents[:, :, h_start:h_end, : w_end - latents.shape[3]], + ), + axis=-1, + ) + else: + latent_view = latents[:, :, h_start:h_end, w_start:w_end] + latents_for_view.append(latent_view) + latents_for_view = torch.cat(latents_for_view) + else: + latents_for_view = torch.cat( + [ + latents[:, :, h_start:h_end, w_start:w_end] + for h_start, h_end, w_start, w_end in batch_view + ] + ) + + # rematch block's scheduler status + self.scheduler.__dict__.update(views_scheduler_status[j]) + + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + latents_for_view.repeat_interleave(2, dim=0) + if do_classifier_free_guidance + else latents_for_view + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # repeat prompt_embeds for batch + prompt_embeds_input = torch.cat([prompt_embeds] * vb_size) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds_input, + timestep_cond=timestep_cond, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2] + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg( + noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents_denoised_batch = self.scheduler.step( + noise_pred, t, latents_for_view, **extra_step_kwargs + ).prev_sample + + # save views scheduler status after sample + views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__) + + # extract value from batch + for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip( + latents_denoised_batch.chunk(vb_size), batch_view + ): + if circular_padding and w_end > latents.shape[3]: + # Case for circular padding + value[:, :, h_start:h_end, w_start:] += latents_view_denoised[ + :, :, h_start:h_end, : latents.shape[3] - w_start + ] + value[:, :, h_start:h_end, : w_end - latents.shape[3]] += latents_view_denoised[ + :, :, h_start:h_end, latents.shape[3] - w_start : + ] + count[:, :, h_start:h_end, w_start:] += 1 + count[:, :, h_start:h_end, : w_end - latents.shape[3]] += 1 + else: + value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised + count[:, :, h_start:h_end, w_start:w_end] += 1 + + # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://huggingface.co/papers/2302.08113 + latents = torch.where(count > 0, value / count, value) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type != "latent": + if circular_padding: + image = self.decode_latents_with_padding(latents) + else: + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_safe/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_safe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b35015a9f72913526ee4bdece969f74d545aecfb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_safe/__init__.py @@ -0,0 +1,99 @@ +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING + +import numpy as np +import PIL +from PIL import Image + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + BaseOutput, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +@dataclass +class SafetyConfig(object): + WEAK = { + "sld_warmup_steps": 15, + "sld_guidance_scale": 20, + "sld_threshold": 0.0, + "sld_momentum_scale": 0.0, + "sld_mom_beta": 0.0, + } + MEDIUM = { + "sld_warmup_steps": 10, + "sld_guidance_scale": 1000, + "sld_threshold": 0.01, + "sld_momentum_scale": 0.3, + "sld_mom_beta": 0.4, + } + STRONG = { + "sld_warmup_steps": 7, + "sld_guidance_scale": 2000, + "sld_threshold": 0.025, + "sld_momentum_scale": 0.5, + "sld_mom_beta": 0.7, + } + MAX = { + "sld_warmup_steps": 0, + "sld_guidance_scale": 5000, + "sld_threshold": 1.0, + "sld_momentum_scale": 0.5, + "sld_mom_beta": 0.7, + } + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {} + +_additional_imports.update({"SafetyConfig": SafetyConfig}) + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure.update( + { + "pipeline_output": ["StableDiffusionSafePipelineOutput"], + "pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"], + "safety_checker": ["StableDiffusionSafetyChecker"], + } + ) + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_output import StableDiffusionSafePipelineOutput + from .pipeline_stable_diffusion_safe import StableDiffusionPipelineSafe + from .safety_checker import SafeStableDiffusionSafetyChecker + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_safe/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_safe/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97adfd5a341a01c98a471cb8c35763675351c45c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_safe/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_safe/__pycache__/pipeline_output.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_safe/__pycache__/pipeline_output.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c8fe2bb5e0baaacc7b265060ce0116996484b8b Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_safe/__pycache__/pipeline_output.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_safe/__pycache__/pipeline_stable_diffusion_safe.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_safe/__pycache__/pipeline_stable_diffusion_safe.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95436bd96d6e8b722f4d2164155c1b268d6eb6e2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_safe/__pycache__/pipeline_stable_diffusion_safe.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_safe/__pycache__/safety_checker.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_safe/__pycache__/safety_checker.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b98edb5ef5d1656ce87c74fc81bffec93ddf29e8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_safe/__pycache__/safety_checker.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_safe/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_safe/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..6b784bb0e10256451d1d3950edeebb58a929cf0c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_safe/pipeline_output.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import ( + BaseOutput, +) + + +@dataclass +class StableDiffusionSafePipelineOutput(BaseOutput): + """ + Output class for Safe Stable Diffusion pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`list[bool]`) + list of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, or `None` if safety checking could not be performed. + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images that were flagged by the safety checker any may contain "not-safe-for-work" + (nsfw) content, or `None` if no safety check was performed or no images were flagged. + applied_safety_concept (`str`) + The safety concept that was applied for safety guidance, or `None` if safety guidance was disabled + """ + + images: list[PIL.Image.Image] | np.ndarray + nsfw_content_detected: list[bool] | None + unsafe_images: list[PIL.Image.Image] | np.ndarray | None + applied_safety_concept: str | None diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py new file mode 100644 index 0000000000000000000000000000000000000000..26bb5128ba9bf57623ff19ab4abf690f07e65b0b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -0,0 +1,784 @@ +import inspect +import warnings +from typing import Callable + +import numpy as np +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...configuration_utils import FrozenDict +from ...image_processor import PipelineImageInput +from ...loaders import IPAdapterMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import deprecate, is_torch_xla_available, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin +from . import StableDiffusionSafePipelineOutput +from .safety_checker import SafeStableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableDiffusionPipelineSafe(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin, IPAdapterMixin): + _last_supported_version = "0.33.1" + + r""" + Pipeline based on the [`StableDiffusionPipeline`] for text-to-image generation using Safe Latent Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: SafeStableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection | None = None, + requires_safety_checker: bool = True, + ): + super().__init__() + safety_concept: str | None = ( + "an image showing hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity," + " bodily fluids, blood, obscene gestures, illegal activity, drug use, theft, vandalism, weapons, child" + " abuse, brutality, cruelty" + ) + + if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5" + " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self._safety_text_concept = safety_concept + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.register_to_config(requires_safety_checker=requires_safety_checker) + + @property + def safety_concept(self): + r""" + Getter method for the safety concept used with SLD + + Returns: + `str`: The text describing the safety concept + """ + return self._safety_text_concept + + @safety_concept.setter + def safety_concept(self, concept): + r""" + Setter method for the safety concept used with SLD + + Args: + concept (`str`): + The text of the new safety concept + """ + self._safety_text_concept = concept + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + enable_safety_guidance, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # Encode the safety concept text + if enable_safety_guidance: + safety_concept_input = self.tokenizer( + [self._safety_text_concept], + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + safety_embeddings = self.text_encoder(safety_concept_input.input_ids.to(self.device))[0] + + # duplicate safety embeddings for each generation per prompt, using mps friendly method + seq_len = safety_embeddings.shape[1] + safety_embeddings = safety_embeddings.repeat(batch_size, num_images_per_prompt, 1) + safety_embeddings = safety_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance + sld, we need to do three forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing three forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, safety_embeddings]) + + else: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def run_safety_checker(self, image, device, dtype, enable_safety_guidance): + if self.safety_checker is not None: + images = image.copy() + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + flagged_images = np.zeros((2, *image.shape[1:])) + if any(has_nsfw_concept): + logger.warning( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead." + f"{'You may look at this images in the `unsafe_images` variable of the output at your own discretion.' if enable_safety_guidance else 'Try again with a different prompt and/or seed.'}" + ) + for idx, has_nsfw_concept in enumerate(has_nsfw_concept): + if has_nsfw_concept: + flagged_images[idx] = images[idx] + image[idx] = np.zeros(image[idx].shape) # black image + else: + has_nsfw_concept = None + flagged_images = None + return image, has_nsfw_concept, flagged_images + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def perform_safety_guidance( + self, + enable_safety_guidance, + safety_momentum, + noise_guidance, + noise_pred_out, + i, + sld_guidance_scale, + sld_warmup_steps, + sld_threshold, + sld_momentum_scale, + sld_mom_beta, + ): + # Perform SLD guidance + if enable_safety_guidance: + if safety_momentum is None: + safety_momentum = torch.zeros_like(noise_guidance) + noise_pred_text, noise_pred_uncond = noise_pred_out[0], noise_pred_out[1] + noise_pred_safety_concept = noise_pred_out[2] + + # Equation 6 + scale = torch.clamp(torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0) + + # Equation 6 + safety_concept_scale = torch.where( + (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, torch.zeros_like(scale), scale + ) + + # Equation 4 + noise_guidance_safety = torch.mul((noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale) + + # Equation 7 + noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum + + # Equation 8 + safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety + + if i >= sld_warmup_steps: # Warmup + # Equation 3 + noise_guidance = noise_guidance - noise_guidance_safety + return noise_guidance, safety_momentum + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str], + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + sld_guidance_scale: float | None = 1000, + sld_warmup_steps: int | None = 10, + sld_threshold: float | None = 0.01, + sld_momentum_scale: float | None = 0.3, + sld_mom_beta: float | None = 0.4, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + sld_guidance_scale (`float`, *optional*, defaults to 1000): + If `sld_guidance_scale < 1`, safety guidance is disabled. + sld_warmup_steps (`int`, *optional*, defaults to 10): + Number of warmup steps for safety guidance. SLD is only be applied for diffusion steps greater than + `sld_warmup_steps`. + sld_threshold (`float`, *optional*, defaults to 0.01): + Threshold that separates the hyperplane between appropriate and inappropriate images. + sld_momentum_scale (`float`, *optional*, defaults to 0.3): + Scale of the SLD momentum to be added to the safety guidance at each diffusion step. If set to 0.0, + momentum is disabled. Momentum is built up during warmup for diffusion steps smaller than + `sld_warmup_steps`. + sld_mom_beta (`float`, *optional*, defaults to 0.4): + Defines how safety guidance momentum builds up. `sld_mom_beta` indicates how much of the previous + momentum is kept. Momentum is built up during warmup for diffusion steps smaller than + `sld_warmup_steps`. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + + Examples: + + ```py + import torch + from diffusers import StableDiffusionPipelineSafe + from diffusers.pipelines.stable_diffusion_safe import SafetyConfig + + pipeline = StableDiffusionPipelineSafe.from_pretrained( + "AIML-TUDA/stable-diffusion-safe", torch_dtype=torch.float16 + ).to("cuda") + prompt = "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c. leyendecker" + image = pipeline(prompt=prompt, **SafetyConfig.MEDIUM).images[0] + ``` + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + enable_safety_guidance = sld_guidance_scale > 1.0 and do_classifier_free_guidance + if not enable_safety_guidance: + warnings.warn("Safety checker disabled!") + + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) + if do_classifier_free_guidance: + if enable_safety_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds, image_embeds]) + else: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + safety_momentum = None + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * (3 if enable_safety_guidance else 2)) + if do_classifier_free_guidance + else latents + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, t, encoder_hidden_states=prompt_embeds, added_cond_kwargs=added_cond_kwargs + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_out = noise_pred.chunk((3 if enable_safety_guidance else 2)) + noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] + + # default classifier free guidance + noise_guidance = noise_pred_text - noise_pred_uncond + + # Perform SLD guidance + if enable_safety_guidance: + if safety_momentum is None: + safety_momentum = torch.zeros_like(noise_guidance) + noise_pred_safety_concept = noise_pred_out[2] + + # Equation 6 + scale = torch.clamp( + torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0 + ) + + # Equation 6 + safety_concept_scale = torch.where( + (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, + torch.zeros_like(scale), + scale, + ) + + # Equation 4 + noise_guidance_safety = torch.mul( + (noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale + ) + + # Equation 7 + noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum + + # Equation 8 + safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety + + if i >= sld_warmup_steps: # Warmup + # Equation 3 + noise_guidance = noise_guidance - noise_guidance_safety + + noise_pred = noise_pred_uncond + guidance_scale * noise_guidance + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept, flagged_images = self.run_safety_checker( + image, device, prompt_embeds.dtype, enable_safety_guidance + ) + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + if flagged_images is not None: + flagged_images = self.numpy_to_pil(flagged_images) + + if not return_dict: + return ( + image, + has_nsfw_concept, + self._safety_text_concept if enable_safety_guidance else None, + flagged_images, + ) + + return StableDiffusionSafePipelineOutput( + images=image, + nsfw_content_detected=has_nsfw_concept, + applied_safety_concept=self._safety_text_concept if enable_safety_guidance else None, + unsafe_images=flagged_images, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_safe/safety_checker.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_safe/safety_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..1f6ad5f2a348c3a0aed208011e40e507fe0195e5 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_safe/safety_checker.py @@ -0,0 +1,109 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel + +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +def cosine_distance(image_embeds, text_embeds): + normalized_image_embeds = nn.functional.normalize(image_embeds) + normalized_text_embeds = nn.functional.normalize(text_embeds) + return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) + + +class SafeStableDiffusionSafetyChecker(PreTrainedModel): + config_class = CLIPConfig + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPConfig): + super().__init__(config) + + self.vision_model = CLIPVisionModel(config.vision_config) + self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False) + + self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False) + self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False) + + self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False) + self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False) + + @torch.no_grad() + def forward(self, clip_input, images): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy() + cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy() + + result = [] + batch_size = image_embeds.shape[0] + for i in range(batch_size): + result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} + + # increase this value to create a stronger `nfsw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + for concept_idx in range(len(special_cos_dist[0])): + concept_cos = special_cos_dist[i][concept_idx] + concept_threshold = self.special_care_embeds_weights[concept_idx].item() + result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["special_scores"][concept_idx] > 0: + result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) + adjustment = 0.01 + + for concept_idx in range(len(cos_dist[0])): + concept_cos = cos_dist[i][concept_idx] + concept_threshold = self.concept_embeds_weights[concept_idx].item() + result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["concept_scores"][concept_idx] > 0: + result_img["bad_concepts"].append(concept_idx) + + result.append(result_img) + + has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] + + return images, has_nsfw_concepts + + @torch.no_grad() + def forward_onnx(self, clip_input: torch.Tensor, images: torch.Tensor): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) + cos_dist = cosine_distance(image_embeds, self.concept_embeds) + + # increase this value to create a stronger `nsfw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment + # special_scores = special_scores.round(decimals=3) + special_care = torch.any(special_scores > 0, dim=1) + special_adjustment = special_care * 0.01 + special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) + + concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment + # concept_scores = concept_scores.round(decimals=3) + has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) + + return images, has_nsfw_concepts diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_sag/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_sag/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..378e0e57817f58a0a28afed5d6110f6ee3effb3a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_sag/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_sag/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_sag/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bb40b9e2630ad9bdf70ca0dfffba045b892e208 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_sag/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_sag/__pycache__/pipeline_stable_diffusion_sag.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_sag/__pycache__/pipeline_stable_diffusion_sag.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1384dc134d4dfc5722e57d7bc9e1f54958c62d6 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_sag/__pycache__/pipeline_stable_diffusion_sag.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py new file mode 100644 index 0000000000000000000000000000000000000000..3cf604911f0b62509c3a5b0c223e301d8c8ddbe1 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py @@ -0,0 +1,969 @@ +# Copyright 2025 Susung Hong and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionSAGPipeline + + >>> pipe = StableDiffusionSAGPipeline.from_pretrained( + ... "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, sag_scale=0.75).images[0] + ``` +""" + + +# processes and stores attention probabilities +class CrossAttnStoreProcessor: + def __init__(self): + self.attention_probs = None + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + ): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + self.attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(self.attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +# Modified to get self-attention guidance scale in this paper (https://huggingface.co/papers/2210.00939) as an input +class StableDiffusionSAGPipeline( + DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, IPAdapterMixin +): + _last_supported_version = "0.33.1" + + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection | None = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack( + [single_negative_image_embeds] * num_images_per_prompt, dim=0 + ) + + if do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + else: + image_embeds = ip_adapter_image_embeds + return image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + sag_scale: float = 0.75, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int | None = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + sag_scale (`float`, *optional*, defaults to 0.75): + Chosen between [0, 1.0] for better quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. If not provided, embeddings are computed from the + `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # and `sag_scale` is` `s` of equation (16) + # of the self-attention guidance paper: https://huggingface.co/papers/2210.00939 + # `sag_scale = 0` means no self-attention guidance + do_self_attention_guidance = sag_scale > 0.0 + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + do_classifier_free_guidance, + ) + + if do_classifier_free_guidance: + image_embeds = [] + negative_image_embeds = [] + for tmp_image_embeds in ip_adapter_image_embeds: + single_negative_image_embeds, single_image_embeds = tmp_image_embeds.chunk(2) + image_embeds.append(single_image_embeds) + negative_image_embeds.append(single_negative_image_embeds) + else: + image_embeds = ip_adapter_image_embeds + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + if timesteps.dtype not in [torch.int16, torch.int32, torch.int64]: + raise ValueError( + f"{self.__class__.__name__} does not support using a scheduler of type {self.scheduler.__class__.__name__}. Please make sure to use one of 'DDIMScheduler, PNDMScheduler, DDPMScheduler, DEISMultistepScheduler, UniPCMultistepScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler'." + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + if do_classifier_free_guidance: + added_uncond_kwargs = ( + {"image_embeds": negative_image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 7. Denoising loop + original_attn_proc = self.unet.attn_processors + store_processor = CrossAttnStoreProcessor() + self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + map_size = None + + def get_map_size(module, input, output): + nonlocal map_size + map_size = output[0].shape[-2:] + + with self.unet.mid_block.attentions[0].register_forward_hook(get_map_size): + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # perform self-attention guidance with the stored self-attention map + if do_self_attention_guidance: + # classifier-free guidance produces two chunks of attention map + # and we only use unconditional one according to equation (25) + # in https://huggingface.co/papers/2210.00939 + if do_classifier_free_guidance: + # DDIM-like prediction of x0 + pred_x0 = self.pred_x0(latents, noise_pred_uncond, t) + # get the stored attention maps + uncond_attn, cond_attn = store_processor.attention_probs.chunk(2) + # self-attention-based degrading of latents + degraded_latents = self.sag_masking( + pred_x0, uncond_attn, map_size, t, self.pred_epsilon(latents, noise_pred_uncond, t) + ) + uncond_emb, _ = prompt_embeds.chunk(2) + # forward and give guidance + degraded_pred = self.unet( + degraded_latents, + t, + encoder_hidden_states=uncond_emb, + added_cond_kwargs=added_uncond_kwargs, + ).sample + noise_pred += sag_scale * (noise_pred_uncond - degraded_pred) + else: + # DDIM-like prediction of x0 + pred_x0 = self.pred_x0(latents, noise_pred, t) + # get the stored attention maps + cond_attn = store_processor.attention_probs + # self-attention-based degrading of latents + degraded_latents = self.sag_masking( + pred_x0, cond_attn, map_size, t, self.pred_epsilon(latents, noise_pred, t) + ) + # forward and give guidance + degraded_pred = self.unet( + degraded_latents, + t, + encoder_hidden_states=prompt_embeds, + added_cond_kwargs=added_cond_kwargs, + ).sample + noise_pred += sag_scale * (noise_pred - degraded_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + self.maybe_free_model_hooks() + # make sure to set the original attention processors back + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + def sag_masking(self, original_latents, attn_map, map_size, t, eps): + # Same masking process as in SAG paper: https://huggingface.co/papers/2210.00939 + bh, hw1, hw2 = attn_map.shape + b, latent_channel, latent_h, latent_w = original_latents.shape + h = self.unet.config.attention_head_dim + if isinstance(h, list): + h = h[-1] + + # Produce attention mask + attn_map = attn_map.reshape(b, h, hw1, hw2) + attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0 + attn_mask = ( + attn_mask.reshape(b, map_size[0], map_size[1]) + .unsqueeze(1) + .repeat(1, latent_channel, 1, 1) + .type(attn_map.dtype) + ) + attn_mask = F.interpolate(attn_mask, (latent_h, latent_w)) + + # Blur according to the self-attention mask + degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0) + degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask) + + # Noise it again to match the noise level + degraded_latents = self.scheduler.add_noise(degraded_latents, noise=eps, timesteps=t[None]) + + return degraded_latents + + # Modified from diffusers.schedulers.scheduling_ddim.DDIMScheduler.step + # Note: there are some schedulers that clip or do not return x_0 (PNDMScheduler, DDIMScheduler, etc.) + def pred_x0(self, sample, model_output, timestep): + alpha_prod_t = self.scheduler.alphas_cumprod[timestep].to(sample.device) + + beta_prod_t = 1 - alpha_prod_t + if self.scheduler.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.scheduler.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.scheduler.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # predict V + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," + " or `v_prediction`" + ) + + return pred_original_sample + + def pred_epsilon(self, sample, model_output, timestep): + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] + + beta_prod_t = 1 - alpha_prod_t + if self.scheduler.config.prediction_type == "epsilon": + pred_eps = model_output + elif self.scheduler.config.prediction_type == "sample": + pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (beta_prod_t**0.5) + elif self.scheduler.config.prediction_type == "v_prediction": + pred_eps = (beta_prod_t**0.5) * sample + (alpha_prod_t**0.5) * model_output + else: + raise ValueError( + f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," + " or `v_prediction`" + ) + + return pred_eps + + +# Gaussian blur +def gaussian_blur_2d(img, kernel_size, sigma): + ksize_half = (kernel_size - 1) * 0.5 + + x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) + + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + + x_kernel = pdf / pdf.sum() + x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) + + kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) + kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) + + padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] + + img = F.pad(img, padding, mode="reflect") + img = F.conv2d(img, kernel2d, groups=img.shape[-3]) + + return img diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8088fbcfceba205b9b908613f4ca3fdc579120e8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__init__.py @@ -0,0 +1,76 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["StableDiffusionXLPipelineOutput"]} + +if is_transformers_available() and is_flax_available(): + _import_structure["pipeline_output"].extend(["FlaxStableDiffusionXLPipelineOutput"]) +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_xl"] = ["StableDiffusionXLPipeline"] + _import_structure["pipeline_stable_diffusion_xl_img2img"] = ["StableDiffusionXLImg2ImgPipeline"] + _import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"] + _import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"] + +if is_transformers_available() and is_flax_available(): + from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState + + _additional_imports.update({"PNDMSchedulerState": PNDMSchedulerState}) + _import_structure["pipeline_flax_stable_diffusion_xl"] = ["FlaxStableDiffusionXLPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_stable_diffusion_xl import StableDiffusionXLPipeline + from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline + from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline + from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline + + try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_flax_objects import * + else: + from .pipeline_flax_stable_diffusion_xl import ( + FlaxStableDiffusionXLPipeline, + ) + from .pipeline_output import FlaxStableDiffusionXLPipelineOutput + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b8fa14c7a61a5ecc2e1754273101532bf178447 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_flax_stable_diffusion_xl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_flax_stable_diffusion_xl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e99f689b2790567261fd1232ab253332c3b4913d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_flax_stable_diffusion_xl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_output.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_output.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cdfd8f2f286b7f7d0c273f7174288730c27684c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_output.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12fb7338c92a7048f10313f0aca7e3576c78381a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_img2img.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_img2img.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94f1b04b31e3cce2e0b8576a811fb1380b0efc65 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_img2img.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_inpaint.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_inpaint.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7233d5a9faef6e2557ec8aa513837862673cc887 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_inpaint.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_instruct_pix2pix.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_instruct_pix2pix.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccf10f1702b0967254941a249c3ae0e929e3f498 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__pycache__/pipeline_stable_diffusion_xl_instruct_pix2pix.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__pycache__/watermark.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__pycache__/watermark.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96e91f0d8dde62bf0c0193c662cc5f02cab6f99f Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/__pycache__/watermark.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..f80abc7073ed7406f9e3b6297ac340a86bf76091 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py @@ -0,0 +1,307 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from functools import partial + +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict +from transformers import CLIPTokenizer, FlaxCLIPTextModel + +from diffusers.utils import logging + +from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel +from ...schedulers import ( + FlaxDDIMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, +) +from ..pipeline_flax_utils import FlaxDiffusionPipeline +from .pipeline_output import FlaxStableDiffusionXLPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Set to True to use python for loop instead of jax.fori_loop for easier debugging +DEBUG = False + + +class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline): + def __init__( + self, + text_encoder: FlaxCLIPTextModel, + text_encoder_2: FlaxCLIPTextModel, + vae: FlaxAutoencoderKL, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: FlaxUNet2DConditionModel, + scheduler: FlaxDDIMScheduler | FlaxPNDMScheduler | FlaxLMSDiscreteScheduler | FlaxDPMSolverMultistepScheduler, + dtype: jnp.dtype = jnp.float32, + ): + super().__init__() + self.dtype = dtype + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + + def prepare_inputs(self, prompt: str | list[str]): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + # Assume we have the two encoders + inputs = [] + for tokenizer in [self.tokenizer, self.tokenizer_2]: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + inputs.append(text_inputs.input_ids) + inputs = jnp.stack(inputs, axis=1) + return inputs + + def __call__( + self, + prompt_ids: jax.Array, + params: dict | FrozenDict, + prng_seed: jax.Array, + num_inference_steps: int = 50, + guidance_scale: float | jax.Array = 7.5, + height: int | None = None, + width: int | None = None, + latents: jnp.array = None, + neg_prompt_ids: jnp.array = None, + return_dict: bool = True, + output_type: str = None, + jit: bool = False, + ): + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + if isinstance(guidance_scale, float) and jit: + # Convert to a tensor so each device gets a copy. + guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) + guidance_scale = guidance_scale[:, None] + + return_latents = output_type == "latent" + + if jit: + images = _p_generate( + self, + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + return_latents, + ) + else: + images = self._generate( + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + return_latents, + ) + + if not return_dict: + return (images,) + + return FlaxStableDiffusionXLPipelineOutput(images=images) + + def get_embeddings(self, prompt_ids: jnp.array, params): + # We assume we have the two encoders + + # bs, encoder_input, seq_length + te_1_inputs = prompt_ids[:, 0, :] + te_2_inputs = prompt_ids[:, 1, :] + + prompt_embeds = self.text_encoder(te_1_inputs, params=params["text_encoder"], output_hidden_states=True) + prompt_embeds = prompt_embeds["hidden_states"][-2] + prompt_embeds_2_out = self.text_encoder_2( + te_2_inputs, params=params["text_encoder_2"], output_hidden_states=True + ) + prompt_embeds_2 = prompt_embeds_2_out["hidden_states"][-2] + text_embeds = prompt_embeds_2_out["text_embeds"] + prompt_embeds = jnp.concatenate([prompt_embeds, prompt_embeds_2], axis=-1) + return prompt_embeds, text_embeds + + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, bs, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = jnp.array([add_time_ids] * bs, dtype=dtype) + return add_time_ids + + def _generate( + self, + prompt_ids: jnp.array, + params: dict | FrozenDict, + prng_seed: jax.Array, + num_inference_steps: int, + height: int, + width: int, + guidance_scale: float, + latents: jnp.array | None = None, + neg_prompt_ids: jnp.array | None = None, + return_latents=False, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # Encode input prompt + prompt_embeds, pooled_embeds = self.get_embeddings(prompt_ids, params) + + # Get unconditional embeddings + batch_size = prompt_embeds.shape[0] + if neg_prompt_ids is None: + neg_prompt_embeds = jnp.zeros_like(prompt_embeds) + negative_pooled_embeds = jnp.zeros_like(pooled_embeds) + else: + neg_prompt_embeds, negative_pooled_embeds = self.get_embeddings(neg_prompt_ids, params) + + add_time_ids = self._get_add_time_ids( + (height, width), (0, 0), (height, width), prompt_embeds.shape[0], dtype=prompt_embeds.dtype + ) + + prompt_embeds = jnp.concatenate([neg_prompt_embeds, prompt_embeds], axis=0) # (2, 77, 2048) + add_text_embeds = jnp.concatenate([negative_pooled_embeds, pooled_embeds], axis=0) + add_time_ids = jnp.concatenate([add_time_ids, add_time_ids], axis=0) + + # Ensure model output will be `float32` before going into the scheduler + guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32) + + # Create random latents + latents_shape = ( + batch_size, + self.unet.config.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if latents is None: + latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + # Prepare scheduler state + scheduler_state = self.scheduler.set_timesteps( + params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape + ) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * scheduler_state.init_noise_sigma + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # Denoising loop + def loop_body(step, args): + latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": params["unet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=prompt_embeds, + added_cond_kwargs=added_cond_kwargs, + ).sample + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents, scheduler_state + + if DEBUG: + # run with python for loop + for i in range(num_inference_steps): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + else: + latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + + if return_latents: + return latents + + # Decode latents + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample + + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image + + +# Static argnums are pipe, num_inference_steps, height, width, return_latents. A change would trigger recompilation. +# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). +@partial( + jax.pmap, + in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0, None), + static_broadcasted_argnums=(0, 4, 5, 6, 10), +) +def _p_generate( + pipe, + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + return_latents, +): + return pipe._generate( + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + neg_prompt_ids, + return_latents, + ) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..6617e7cd058cc701de501771d8a29e04cc929900 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_output.py @@ -0,0 +1,36 @@ +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput, is_flax_available + + +@dataclass +class StableDiffusionXLPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image] | np.ndarray + + +if is_flax_available(): + import flax + + @flax.struct.dataclass + class FlaxStableDiffusionXLPipelineOutput(BaseOutput): + """ + Output class for Flax Stable Diffusion XL pipelines. + + Args: + images (`np.ndarray`) + Array of shape `(batch_size, height, width, num_channels)` with images from the diffusion pipeline. + """ + + images: np.ndarray diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..2f6b105702e88cbfbe49652f93b704f436a00984 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -0,0 +1,1300 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from .watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLPipeline + + >>> pipe = StableDiffusionXLPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLPipeline( + DiffusionPipeline, + StableDiffusionMixin, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "add_text_embeds", + "add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool | None = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + denoising_end: float | None = None, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + guidance_rescale: float = 0.0, + original_size: tuple[int, int] | None = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] | None = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9. Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..19ccfab3de0a9e8bf833388211e2d9e386b1e59e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -0,0 +1,1492 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import empty_device_cache, randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from .watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + >>> url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png" + + >>> init_image = load_image(url).convert("RGB") + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt, image=init_image).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLImg2ImgPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + IPAdapterMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): + Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the + config of `stabilityai/stable-diffusion-xl-refiner-1-0`. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "add_text_embeds", + "add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool | None = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + strength, + num_inference_steps, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if num_inference_steps is None: + raise ValueError("`num_inference_steps` cannot be None.") + elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(self.scheduler.timesteps) - num_inference_steps + timesteps = self.scheduler.timesteps[t_start:] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + def prepare_latents( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + latents_mean = latents_std = None + if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None: + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1) + if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None: + latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1) + + # Offload text encoder if `enable_model_cpu_offload` was enabled + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.text_encoder_2.to("cpu") + empty_device_cache() + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: + image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) + elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) + if latents_mean is not None and latents_std is not None: + latents_mean = latents_mean.to(device=device, dtype=dtype) + latents_std = latents_std.to(device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std + else: + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def denoising_start(self): + return self._denoising_start + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + image: PipelineImageInput = None, + strength: float = 0.3, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + denoising_start: float | None = None, + denoising_end: float | None = None, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + guidance_rescale: float = 0.0, + original_size: tuple[int, int] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`torch.Tensor` or `PIL.Image.Image` or `np.ndarray` or `list[torch.Tensor]` or `list[PIL.Image.Image]` or `list[np.ndarray]`): + The image(s) to modify with the pipeline. + strength (`float`, *optional*, defaults to 0.3): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of + `denoising_start` being declared as an integer, the value of `strength` will be ignored. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_start (`float`, *optional*): + When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be + bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and + it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, + strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline + is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refine Image + Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality). + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be + denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the + final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline + forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refine Image + Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality). + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + strength, + num_inference_steps, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._denoising_start = denoising_start + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. Prepare timesteps + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None, + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + add_noise = True if self.denoising_start is None else False + + # 6. Prepare latent variables + if latents is None: + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + add_noise, + ) + # 7. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + height, width = latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 8. Prepare added time ids & embeddings + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 9. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 9.1 Apply denoising_end + if ( + self.denoising_end is not None + and self.denoising_start is not None + and denoising_value_valid(self.denoising_end) + and denoising_value_valid(self.denoising_start) + and self.denoising_start >= self.denoising_end + ): + raise ValueError( + f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " + + f" {self.denoising_end} when using type float." + ) + elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..8de7d4f0bb7d1bf22074b81ae07dc29d734af14d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -0,0 +1,1729 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from .watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLInpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", + ... torch_dtype=torch.float16, + ... variant="fp16", + ... use_safetensors=True, + ... ) + >>> pipe.to("cuda") + + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + >>> init_image = load_image(img_url).convert("RGB") + >>> mask_image = load_image(mask_url).convert("RGB") + + >>> prompt = "A majestic tiger sitting on a bench" + >>> image = pipe( + ... prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80 + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +def mask_pil_to_torch(mask, height, width): + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask = torch.from_numpy(mask) + return mask + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLInpaintPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + FromSingleFileMixin, + IPAdapterMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): + Whether the `unet` requires a aesthetic_score condition to be passed during inference. Also see the config + of `stabilityai/stable-diffusion-xl-refiner-1-0`. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "add_text_embeds", + "add_time_ids", + "mask", + "masked_image_latents", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + requires_aesthetics_score: bool = False, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool | None = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + image, + mask_image, + height, + width, + strength, + callback_steps, + output_type, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + add_noise=True, + return_noise=False, + return_image_latents=False, + ): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if image.shape[1] == 4: + image_latents = image.to(device=device, dtype=dtype) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + elif return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None and add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + elif add_noise: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + else: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = image_latents.to(device) + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + dtype = image.dtype + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + image_latents = image_latents.to(dtype) + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + else: + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + t_start = len(self.scheduler.timesteps) - num_inference_steps + timesteps = self.scheduler.timesteps[t_start:] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start) + return timesteps, num_inference_steps + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids + def _get_add_time_ids( + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype, + text_encoder_projection_dim=None, + ): + if self.config.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list( + negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,) + ) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if ( + expected_add_embed_dim > passed_add_embed_dim + and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." + ) + elif ( + expected_add_embed_dim < passed_add_embed_dim + and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim + ): + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." + ) + elif expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def denoising_start(self): + return self._denoising_start + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: torch.Tensor = None, + height: int | None = None, + width: int | None = None, + padding_mask_crop: int | None = None, + strength: float = 0.9999, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + denoising_start: float | None = None, + denoising_end: float | None = None, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + cross_attention_kwargs: dict[str, Any] | None = None, + guidance_rescale: float = 0.0, + original_size: tuple[int, int] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: int | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 0.9999): + Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be + between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the + `strength`. The number of denoising steps depends on the amount of noise initially added. When + `strength` is 1, added noise will be maximum and the denoising process will run for the full number of + iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked + portion of the reference `image`. Note that in the case of `denoising_start` being declared as an + integer, the value of `strength` will be ignored. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_start (`float`, *optional*): + When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be + bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and + it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, + strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline + is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be + denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the + final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline + forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. `tuple. When returning a tuple, the first element is a list with the generated images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs + self.check_inputs( + prompt, + prompt_2, + image, + mask_image, + height, + width, + strength, + callback_steps, + output_type, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + padding_mask_crop, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._denoising_start = denoising_start + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. set timesteps + def denoising_value_valid(dnv): + return isinstance(dnv, float) and 0 < dnv < 1 + + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None, + ) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + mask = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is not None: + masked_image = masked_image_latents + elif init_image.shape[1] == 4: + # if images are in latent space, we can't mask it + masked_image = None + else: + masked_image = init_image * (mask < 0.5) + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + add_noise = True if self.denoising_start is None else False + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + add_noise=add_noise, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + ) + + # 8. Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for stable-diffusion-v1-5/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + # 8.1 Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + height, width = latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 10. Prepare added time ids & embeddings + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 11. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + if ( + self.denoising_end is not None + and self.denoising_start is not None + and denoising_value_valid(self.denoising_end) + and denoising_value_valid(self.denoising_start) + and self.denoising_start >= self.denoising_end + ): + raise ValueError( + f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " + + f" {self.denoising_end} when using type float." + ) + elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 11.1 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if num_channels_unet == 4: + init_latents_proper = image_latents + if self.do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + mask = callback_outputs.pop("mask", mask) + masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + return StableDiffusionXLPipelineOutput(images=latents) + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + if padding_mask_crop is not None: + image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py new file mode 100644 index 0000000000000000000000000000000000000000..b79119a94a0c2c7d5904cbbd1f356912bf962998 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -0,0 +1,985 @@ +# Copyright 2025 Harutatsu Akiyama and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import PIL.Image +import torch +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from .watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLInstructPix2PixPipeline + >>> from diffusers.utils import load_image + + >>> resolution = 768 + >>> image = load_image( + ... "https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" + ... ).resize((resolution, resolution)) + >>> edit_instruction = "Turn sky into a cloudy one" + + >>> pipe = StableDiffusionXLInstructPix2PixPipeline.from_pretrained( + ... "diffusers/sdxl-instructpix2pix-768", torch_dtype=torch.float16 + ... ).to("cuda") + + >>> edited_image = pipe( + ... prompt=edit_instruction, + ... image=image, + ... height=resolution, + ... width=resolution, + ... guidance_scale=3.0, + ... image_guidance_scale=1.5, + ... num_inference_steps=30, + ... ).images[0] + >>> edited_image + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class StableDiffusionXLInstructPix2PixPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, +): + r""" + Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`): + Whether the `unet` requires a aesthetic_score condition to be passed during inference. Also see the config + of `stabilityai/stable-diffusion-xl-refiner-1-0`. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + is_cosxl_edit (`bool`, *optional*): + When set the image latents are scaled. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool | None = None, + is_cosxl_edit: bool | None = False, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) + self.is_cosxl_edit = is_cosxl_edit + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.hidden_states[-2] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt, negative_prompt_2] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + prompt_embeds_dtype = self.text_encoder_2.dtype if self.text_encoder_2 is not None else self.unet.dtype + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_image_latents( + self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + image_latents = image + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: + image = image.float() + self.upcast_vae() + + image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax") + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if do_classifier_free_guidance: + uncond_image_latents = torch.zeros_like(image_latents) + image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) + + if image_latents.dtype != self.vae.dtype: + image_latents = image_latents.to(dtype=self.vae.dtype) + + if self.is_cosxl_edit: + image_latents = image_latents * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 100, + denoising_end: float | None = None, + guidance_scale: float = 5.0, + image_guidance_scale: float = 1.5, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + guidance_rescale: float = 0.0, + original_size: tuple[int, int] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`torch.Tensor` or `PIL.Image.Image` or `np.ndarray` or `list[torch.Tensor]` or `list[PIL.Image.Image]` or `list[np.ndarray]`): + The image(s) to modify with the pipeline. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + image_guidance_scale (`float`, *optional*, defaults to 1.5): + Image guidance scale is to push the generated image towards the initial image `image`. Image guidance + scale is enabled by setting `image_guidance_scale > 1`. Higher image guidance scale encourages to + generate images that are closely linked to the source image `image`, usually at the expense of lower + image quality. This pipeline requires a value of at least `1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 and image_guidance_scale >= 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Preprocess image + image = self.image_processor.preprocess(image, height=height, width=width).to(device) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare Image latents + image_latents = self.prepare_image_latents( + image, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + do_classifier_free_guidance, + ) + + # 7. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 8. Check that shapes of latents and image match the UNet channels + num_channels_image = image_latents.shape[1] + if num_channels_latents + num_channels_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_image`: {num_channels_image} " + f" = {num_channels_latents + num_channels_image}. Please verify the config of" + " `pipeline.unet` or your `image` input." + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 10. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if do_classifier_free_guidance: + # The extra concat similar to how it's done in SD InstructPix2Pix. + prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds], dim=0) + add_text_embeds = torch.cat( + [add_text_embeds, negative_pooled_prompt_embeds, negative_pooled_prompt_embeds], dim=0 + ) + add_time_ids = torch.cat([add_time_ids, add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 11. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Expand the latents if we are doing classifier free guidance. + # The latents are expanded 3 times because for pix2pix the guidance + # is applied for both the text and the input image. + latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents + + # concat latents, image_latents in the channel dimension + scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + scaled_latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3) + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_image) + + image_guidance_scale * (noise_pred_image - noise_pred_uncond) + ) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + return StableDiffusionXLPipelineOutput(images=latents) + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/watermark.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/watermark.py new file mode 100644 index 0000000000000000000000000000000000000000..70d06bb6320d890ea90542cc7551bc42d7c59451 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/watermark.py @@ -0,0 +1,42 @@ +import numpy as np +import torch + +from ...utils import is_invisible_watermark_available + + +if is_invisible_watermark_available(): + from imwatermark import WatermarkEncoder + + +# Copied from https://github.com/Stability-AI/generative-models/blob/613af104c6b85184091d42d374fef420eddb356d/scripts/demo/streamlit_helpers.py#L66 +WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110 +# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 +WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] + + +class StableDiffusionXLWatermarker: + def __init__(self): + self.watermark = WATERMARK_BITS + self.encoder = WatermarkEncoder() + + self.encoder.set_watermark("bits", self.watermark) + + def apply_watermark(self, images: torch.Tensor): + # can't encode images that are smaller than 256 + if images.shape[-1] < 256: + return images + + images = (255 * (images / 2 + 0.5)).cpu().permute(0, 2, 3, 1).float().numpy() + + # Convert RGB to BGR, which is the channel order expected by the watermark encoder. + images = images[:, :, :, ::-1] + + # Add watermark and convert BGR back to RGB + images = [self.encoder.encode(image, "dwtDct")[:, :, ::-1] for image in images] + + images = np.array(images) + + images = torch.from_numpy(images).permute(0, 3, 1, 2) + + images = torch.clamp(2 * (images / 255 - 0.5), min=-1.0, max=1.0) + return images diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_video_diffusion/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_video_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3bd4dc78966e217d85769691b98ed8fb0b6ac05c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_video_diffusion/__init__.py @@ -0,0 +1,58 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + BaseOutput, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure.update( + { + "pipeline_stable_video_diffusion": [ + "StableVideoDiffusionPipeline", + "StableVideoDiffusionPipelineOutput", + ], + } + ) + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_stable_video_diffusion import ( + StableVideoDiffusionPipeline, + StableVideoDiffusionPipelineOutput, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..05877f69d403ba6bfb4220814f40761f6d3a9996 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -0,0 +1,743 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Callable + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel +from ...schedulers import EulerDiscreteScheduler +from ...utils import BaseOutput, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import StableVideoDiffusionPipeline + >>> from diffusers.utils import load_image, export_to_video + + >>> pipe = StableVideoDiffusionPipeline.from_pretrained( + ... "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16" + ... ) + >>> pipe.to("cuda") + + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd-docstring-example.jpeg" + ... ) + >>> image = image.resize((1024, 576)) + + >>> frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0] + >>> export_to_video(frames, "generated.mp4", fps=7) + ``` +""" + + +def _append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + return x[(...,) + (None,) * dims_to_append] + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class StableVideoDiffusionPipelineOutput(BaseOutput): + r""" + Output class for Stable Video Diffusion pipeline. + + Args: + frames (`[list[list[PIL.Image.Image]]`, `np.ndarray`, `torch.Tensor`]): + list of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size, + num_frames, height, width, num_channels)`. + """ + + frames: list[list[PIL.Image.Image]] | np.ndarray | torch.Tensor + + +class StableVideoDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline to generate video from an input image using Stable Video Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKLTemporalDecoder`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder + ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)). + unet ([`UNetSpatioTemporalConditionModel`]): + A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents. + scheduler ([`EulerDiscreteScheduler`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images. + """ + + model_cpu_offload_seq = "image_encoder->unet->vae" + _callback_tensor_inputs = ["latents"] + + def __init__( + self, + vae: AutoencoderKLTemporalDecoder, + image_encoder: CLIPVisionModelWithProjection, + unet: UNetSpatioTemporalConditionModel, + scheduler: EulerDiscreteScheduler, + feature_extractor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(do_resize=True, vae_scale_factor=self.vae_scale_factor) + + def _encode_image( + self, + image: PipelineImageInput, + device: str | torch.device, + num_videos_per_prompt: int, + do_classifier_free_guidance: bool, + ) -> torch.Tensor: + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.video_processor.pil_to_numpy(image) + image = self.video_processor.numpy_to_pt(image) + + # We normalize the image before resizing to match with the original implementation. + # Then we unnormalize it after resizing. + image = image * 2.0 - 1.0 + image = _resize_with_antialiasing(image, (224, 224)) + image = (image + 1.0) / 2.0 + + # Normalize the image with for CLIP input + image = self.feature_extractor( + images=image, + do_normalize=True, + do_center_crop=False, + do_resize=False, + do_rescale=False, + return_tensors="pt", + ).pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + negative_image_embeddings = torch.zeros_like(image_embeddings) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_image_embeddings, image_embeddings]) + + return image_embeddings + + def _encode_vae_image( + self, + image: torch.Tensor, + device: str | torch.device, + num_videos_per_prompt: int, + do_classifier_free_guidance: bool, + ): + image = image.to(device=device) + image_latents = self.vae.encode(image).latent_dist.mode() + + # duplicate image_latents for each generation per prompt, using mps friendly method + image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1) + + if do_classifier_free_guidance: + negative_image_latents = torch.zeros_like(image_latents) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_latents = torch.cat([negative_image_latents, image_latents]) + + return image_latents + + def _get_add_time_ids( + self, + fps: int, + motion_bucket_id: int, + noise_aug_strength: float, + dtype: torch.dtype, + batch_size: int, + num_videos_per_prompt: int, + do_classifier_free_guidance: bool, + ): + add_time_ids = [fps, motion_bucket_id, noise_aug_strength] + + passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1) + + if do_classifier_free_guidance: + add_time_ids = torch.cat([add_time_ids, add_time_ids]) + + return add_time_ids + + def decode_latents(self, latents: torch.Tensor, num_frames: int, decode_chunk_size: int = 14): + # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width] + latents = latents.flatten(0, 1) + + latents = 1 / self.vae.config.scaling_factor * latents + + forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward + accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys()) + + # decode decode_chunk_size frames at a time to avoid OOM + frames = [] + for i in range(0, latents.shape[0], decode_chunk_size): + num_frames_in = latents[i : i + decode_chunk_size].shape[0] + decode_kwargs = {} + if accepts_num_frames: + # we only pass num_frames_in if it's expected + decode_kwargs["num_frames"] = num_frames_in + + frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample + frames.append(frame) + frames = torch.cat(frames, dim=0) + + # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width] + frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + frames = frames.float() + return frames + + def check_inputs(self, image, height, width): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `list[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + def prepare_latents( + self, + batch_size: int, + num_frames: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: str | torch.device, + generator: torch.Generator, + latents: torch.Tensor | None = None, + ): + shape = ( + batch_size, + num_frames, + num_channels_latents // 2, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + if isinstance(self.guidance_scale, (int, float)): + return self.guidance_scale > 1 + return self.guidance_scale.max() > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PIL.Image.Image | list[PIL.Image.Image] | torch.Tensor, + height: int = 576, + width: int = 1024, + num_frames: int | None = None, + num_inference_steps: int = 25, + sigmas: list[float] | None = None, + min_guidance_scale: float = 1.0, + max_guidance_scale: float = 3.0, + fps: int = 7, + motion_bucket_id: int = 127, + noise_aug_strength: float = 0.02, + decode_chunk_size: int | None = None, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + return_dict: bool = True, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `list[PIL.Image.Image]` or `torch.Tensor`): + Image(s) to guide image generation. If you provide a tensor, the expected value range is between `[0, + 1]`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_frames (`int`, *optional*): + The number of video frames to generate. Defaults to `self.unet.config.num_frames` (14 for + `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`). + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. This parameter is modulated by `strength`. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + min_guidance_scale (`float`, *optional*, defaults to 1.0): + The minimum guidance scale. Used for the classifier free guidance with first frame. + max_guidance_scale (`float`, *optional*, defaults to 3.0): + The maximum guidance scale. Used for the classifier free guidance with last frame. + fps (`int`, *optional*, defaults to 7): + Frames per second. The rate at which the generated images shall be exported to a video after + generation. Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training. + motion_bucket_id (`int`, *optional*, defaults to 127): + Used for conditioning the amount of motion for the generation. The higher the number the more motion + will be in the video. + noise_aug_strength (`float`, *optional*, defaults to 0.02): + The amount of noise added to the init image, the higher it is the less the video will look like the + init image. Increase it for more motion. + decode_chunk_size (`int`, *optional*): + The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the + expense of more memory usage. By default, the decoder decodes all frames at once for maximal quality. + For lower memory usage, reduce `decode_chunk_size`. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `pil`, `np` or `pt`. + callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step during inference. The function is called + with the following arguments: + `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. + `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is + returned, otherwise a `tuple` of (`list[list[PIL.Image.Image]]` or `np.ndarray` or `torch.Tensor`) is + returned. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_frames = num_frames if num_frames is not None else self.unet.config.num_frames + decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames + + # 1. Check inputs. Raise error if not correct + self.check_inputs(image, height, width) + + # 2. Define call parameters + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + self._guidance_scale = max_guidance_scale + + # 3. Encode input image + image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance) + + # NOTE: Stable Video Diffusion was conditioned on fps - 1, which is why it is reduced here. + # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188 + fps = fps - 1 + + # 4. Encode input image using VAE + image = self.video_processor.preprocess(image, height=height, width=width).to(device) + noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype) + image = image + noise_aug_strength * noise + + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: + self.vae.to(dtype=torch.float32) + + image_latents = self._encode_vae_image( + image, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + image_latents = image_latents.to(image_embeddings.dtype) + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + + # Repeat the image latents for each frame so we can concatenate them with the noise + # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width] + image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1) + + # 5. Get Added Time IDs + added_time_ids = self._get_add_time_ids( + fps, + motion_bucket_id, + noise_aug_strength, + image_embeddings.dtype, + batch_size, + num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + added_time_ids = added_time_ids.to(device) + + # 6. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, None, sigmas + ) + + # 7. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_frames, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 8. Prepare guidance scale + guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0) + guidance_scale = guidance_scale.to(device, latents.dtype) + guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1) + guidance_scale = _append_dims(guidance_scale, latents.ndim) + + self._guidance_scale = guidance_scale + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # Concatenate image_latents over channels dimension + latent_model_input = torch.cat([latent_model_input, image_latents], dim=2) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=image_embeddings, + added_time_ids=added_time_ids, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + frames = self.decode_latents(latents, num_frames, decode_chunk_size) + frames = self.video_processor.postprocess_video(video=frames, output_type=output_type) + else: + frames = latents + + self.maybe_free_model_hooks() + + if not return_dict: + return frames + + return StableVideoDiffusionPipelineOutput(frames=frames) + + +# resizing utils +# TODO: clean up later +def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True): + h, w = input.shape[-2:] + factors = (h / size[0], w / size[1]) + + # First, we have to determine sigma + # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171 + sigmas = ( + max((factors[0] - 1.0) / 2.0, 0.001), + max((factors[1] - 1.0) / 2.0, 0.001), + ) + + # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma + # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206 + # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now + ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3)) + + # Make sure it is odd + if (ks[0] % 2) == 0: + ks = ks[0] + 1, ks[1] + + if (ks[1] % 2) == 0: + ks = ks[0], ks[1] + 1 + + input = _gaussian_blur2d(input, ks, sigmas) + + output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners) + return output + + +def _compute_padding(kernel_size): + """Compute padding tuple.""" + # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) + # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad + if len(kernel_size) < 2: + raise AssertionError(kernel_size) + computed = [k - 1 for k in kernel_size] + + # for even kernels we need to do asymmetric padding :( + out_padding = 2 * len(kernel_size) * [0] + + for i in range(len(kernel_size)): + computed_tmp = computed[-(i + 1)] + + pad_front = computed_tmp // 2 + pad_rear = computed_tmp - pad_front + + out_padding[2 * i + 0] = pad_front + out_padding[2 * i + 1] = pad_rear + + return out_padding + + +def _filter2d(input, kernel): + # prepare kernel + b, c, h, w = input.shape + tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) + + tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) + + height, width = tmp_kernel.shape[-2:] + + padding_shape: list[int] = _compute_padding([height, width]) + input = torch.nn.functional.pad(input, padding_shape, mode="reflect") + + # kernel and input tensor reshape to align element-wise or batch-wise params + tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) + input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) + + # convolve the tensor with the kernel. + output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) + + out = output.view(b, c, h, w) + return out + + +def _gaussian(window_size: int, sigma): + if isinstance(sigma, float): + sigma = torch.tensor([[sigma]]) + + batch_size = sigma.shape[0] + + x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) + + if window_size % 2 == 0: + x = x + 0.5 + + gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) + + return gauss / gauss.sum(-1, keepdim=True) + + +def _gaussian_blur2d(input, kernel_size, sigma): + if isinstance(sigma, tuple): + sigma = torch.tensor([sigma], dtype=input.dtype) + else: + sigma = sigma.to(dtype=input.dtype) + + ky, kx = int(kernel_size[0]), int(kernel_size[1]) + bs = sigma.shape[0] + kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1)) + kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1)) + out_x = _filter2d(input, kernel_x[..., None, :]) + out = _filter2d(out_x, kernel_y[..., None]) + + return out diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/t2i_adapter/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/t2i_adapter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..08c22a2707fe55770a519db481954881c1cad26e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/t2i_adapter/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_stable_diffusion_adapter"] = ["StableDiffusionAdapterPipeline"] + _import_structure["pipeline_stable_diffusion_xl_adapter"] = ["StableDiffusionXLAdapterPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_stable_diffusion_adapter import StableDiffusionAdapterPipeline + from .pipeline_stable_diffusion_xl_adapter import StableDiffusionXLAdapterPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/t2i_adapter/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/t2i_adapter/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bde52c519be3d33930ce9275f7c5873aca9289e2 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/t2i_adapter/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/t2i_adapter/__pycache__/pipeline_stable_diffusion_adapter.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/t2i_adapter/__pycache__/pipeline_stable_diffusion_adapter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20620a8fbaa1b1feef95841c758266743ee4e40c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/t2i_adapter/__pycache__/pipeline_stable_diffusion_adapter.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/t2i_adapter/__pycache__/pipeline_stable_diffusion_xl_adapter.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/t2i_adapter/__pycache__/pipeline_stable_diffusion_xl_adapter.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d76734e5632eac7653ba23063a6aefcfae943b3 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/t2i_adapter/__pycache__/pipeline_stable_diffusion_xl_adapter.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..ffb877cfd0f66b463ea1994cfb1c944431147140 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -0,0 +1,960 @@ +# Copyright 2025 TencentARC and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + USE_PEFT_BACKEND, + BaseOutput, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +@dataclass +class StableDiffusionAdapterPipelineOutput(BaseOutput): + """ + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`list[bool]`) + list of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, or `None` if safety checking could not be performed. + """ + + images: list[PIL.Image.Image] | np.ndarray + nsfw_content_detected: list[bool] | None + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from PIL import Image + >>> from diffusers.utils import load_image + >>> import torch + >>> from diffusers import StableDiffusionAdapterPipeline, T2IAdapter + + >>> image = load_image( + ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_ref.png" + ... ) + + >>> color_palette = image.resize((8, 8)) + >>> color_palette = color_palette.resize((512, 512), resample=Image.Resampling.NEAREST) + + >>> adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_color_sd14v1", torch_dtype=torch.float16) + >>> pipe = StableDiffusionAdapterPipeline.from_pretrained( + ... "CompVis/stable-diffusion-v1-4", + ... adapter=adapter, + ... torch_dtype=torch.float16, + ... ) + + >>> pipe.to("cuda") + + >>> out_image = pipe( + ... "At night, glowing cubes in front of the beach", + ... image=color_palette, + ... ).images[0] + ``` +""" + + +def _preprocess_adapter_image(image, height, width): + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])) for i in image] + image = [ + i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image + ] # expand [h, w] or [h, w, c] to [b, h, w, c] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + if image[0].ndim == 3: + image = torch.stack(image, dim=0) + elif image[0].ndim == 4: + image = torch.cat(image, dim=0) + else: + raise ValueError( + f"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but receive: {image[0].ndim}" + ) + return image + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter + https://huggingface.co/papers/2302.08453 + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + adapter ([`T2IAdapter`] or [`MultiAdapter`] or `list[T2IAdapter]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple Adapter as a + list, the outputs from each Adapter are added together to create one combined additional conditioning. + adapter_weights (`list[float]`, *optional*, defaults to None): + list of floats representing the weight which will be multiply to each adapter's output before adding them + together. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->adapter->unet->vae" + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + adapter: T2IAdapter | MultiAdapter | list[T2IAdapter], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(adapter, (list, tuple)): + adapter = MultiAdapter(adapter) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + adapter=adapter, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + image, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if isinstance(self.adapter, MultiAdapter): + if not isinstance(image, list): + raise ValueError( + "MultiAdapter is enabled, but `image` is not a list. Please pass a list of images to `image`." + ) + + if len(image) != len(self.adapter.adapters): + raise ValueError( + f"MultiAdapter requires passing the same number of images as adapters. Given {len(image)} images and {len(self.adapter.adapters)} adapters." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _default_height_width(self, height, width, image): + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. + while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[-2] + + # round down to nearest multiple of `self.adapter.downscale_factor` + height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[-1] + + # round down to nearest multiple of `self.adapter.downscale_factor` + width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor + + return height, width + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + image: torch.Tensor | PIL.Image.Image | list[PIL.Image.Image] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + adapter_conditioning_scale: float | list[float] = 1.0, + clip_skip: int | None = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.Tensor`, `PIL.Image.Image`, `list[torch.Tensor]` or `list[PIL.Image.Image]` or `list[list[PIL.Image.Image]]`): + The Adapter input condition. Adapter uses this input condition to generate guidance to Unet. If the + type is specified as `torch.Tensor`, it is passed to Adapter as is. PIL.Image.Image` can also be + accepted as an image. The control image is automatically resized to fit the output image. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + adapter_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the + residual in the original unet. If multiple adapters are specified in init, you can set the + corresponding scale as a list. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height, width = self._default_height_width(height, width, image) + device = self._execution_device + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, image, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + self._guidance_scale = guidance_scale + + if isinstance(self.adapter, MultiAdapter): + adapter_input = [] + + for one_image in image: + one_image = _preprocess_adapter_image(one_image, height, width) + one_image = one_image.to(device=device, dtype=self.adapter.dtype) + adapter_input.append(one_image) + else: + adapter_input = _preprocess_adapter_image(image, height, width) + adapter_input = adapter_input.to(device=device, dtype=self.adapter.dtype) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Denoising loop + if isinstance(self.adapter, MultiAdapter): + adapter_state = self.adapter(adapter_input, adapter_conditioning_scale) + for k, v in enumerate(adapter_state): + adapter_state[k] = v + else: + adapter_state = self.adapter(adapter_input) + for k, v in enumerate(adapter_state): + adapter_state[k] = v * adapter_conditioning_scale + if num_images_per_prompt > 1: + for k, v in enumerate(adapter_state): + adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1) + if self.do_classifier_free_guidance: + for k, v in enumerate(adapter_state): + adapter_state[k] = torch.cat([v] * 2, dim=0) + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=cross_attention_kwargs, + down_intrablock_additional_residuals=[state.clone() for state in adapter_state], + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + has_nsfw_concept = None + elif output_type == "pil": + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + else: + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionAdapterPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..a6dd07847de2df5e6a1e325ac512b02af44875b2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -0,0 +1,1303 @@ +# Copyright 2025 TencentARC and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, MultiAdapter, T2IAdapter, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import T2IAdapter, StableDiffusionXLAdapterPipeline, DDPMScheduler + >>> from diffusers.utils import load_image + + >>> sketch_image = load_image("https://huggingface.co/Adapter/t2iadapter/resolve/main/sketch.png").convert("L") + + >>> model_id = "stabilityai/stable-diffusion-xl-base-1.0" + + >>> adapter = T2IAdapter.from_pretrained( + ... "Adapter/t2iadapter", + ... subfolder="sketch_sdxl_1.0", + ... torch_dtype=torch.float16, + ... adapter_type="full_adapter_xl", + ... ) + >>> scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler") + + >>> pipe = StableDiffusionXLAdapterPipeline.from_pretrained( + ... model_id, adapter=adapter, torch_dtype=torch.float16, variant="fp16", scheduler=scheduler + ... ).to("cuda") + + >>> generator = torch.manual_seed(42) + >>> sketch_image_out = pipe( + ... prompt="a photo of a dog in real world, high quality", + ... negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality", + ... image=sketch_image, + ... generator=generator, + ... guidance_scale=7.5, + ... ).images[0] + ``` +""" + + +def _preprocess_adapter_image(image, height, width): + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])) for i in image] + image = [ + i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image + ] # expand [h, w] or [h, w, c] to [b, h, w, c] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + if image[0].ndim == 3: + image = torch.stack(image, dim=0) + elif image[0].ndim == 4: + image = torch.cat(image, dim=0) + else: + raise ValueError( + f"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but receive: {image[0].ndim}" + ) + return image + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLAdapterPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter + https://huggingface.co/papers/2302.08453 + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + adapter ([`T2IAdapter`] or [`MultiAdapter`] or `list[T2IAdapter]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple Adapter as a + list, the outputs from each Adapter are added together to create one combined additional conditioning. + adapter_weights (`list[float]`, *optional*, defaults to None): + list of floats representing the weight which will be multiply to each adapter's output before adding them + together. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + "image_encoder", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + adapter: T2IAdapter | MultiAdapter | list[T2IAdapter], + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + adapter=adapter, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.check_inputs + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + # Copied from diffusers.pipelines.t2i_adapter.pipeline_stable_diffusion_adapter.StableDiffusionAdapterPipeline._default_height_width + def _default_height_width(self, height, width, image): + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. + while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[-2] + + # round down to nearest multiple of `self.adapter.downscale_factor` + height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[-1] + + # round down to nearest multiple of `self.adapter.downscale_factor` + width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor + + return height, width + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + prompt_2: str | list[str] | None = None, + image: PipelineImageInput = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + timesteps: list[int] = None, + sigmas: list[float] = None, + denoising_end: float | None = None, + guidance_scale: float = 5.0, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + ip_adapter_image: PipelineImageInput | None = None, + ip_adapter_image_embeds: list[torch.Tensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + guidance_rescale: float = 0.0, + original_size: tuple[int, int] | None = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] | None = None, + negative_original_size: tuple[int, int] | None = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: tuple[int, int] | None = None, + adapter_conditioning_scale: float | list[float] = 1.0, + adapter_conditioning_factor: float = 1.0, + clip_skip: int | None = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`torch.Tensor`, `PIL.Image.Image`, `list[torch.Tensor]` or `list[PIL.Image.Image]` or `list[list[PIL.Image.Image]]`): + The Adapter input condition. Adapter uses this input condition to generate guidance to Unet. If the + type is specified as `torch.Tensor`, it is passed to Adapter as is. PIL.Image.Image` can also be + accepted as an image. The control image is automatically resized to fit the output image. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionAdapterPipelineOutput`] + instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + adapter_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0): + The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the + residual in the original unet. If multiple adapters are specified in init, you can set the + corresponding scale as a list. + adapter_conditioning_factor (`float`, *optional*, defaults to 1.0): + The fraction of timesteps for which adapter should be applied. If `adapter_conditioning_factor` is + `0.0`, adapter is not applied at all. If `adapter_conditioning_factor` is `1.0`, adapter is applied for + all timesteps. If `adapter_conditioning_factor` is `0.5`, adapter is applied for half of the timesteps. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + + height, width = self._default_height_width(height, width, image) + device = self._execution_device + + if isinstance(self.adapter, MultiAdapter): + adapter_input = [] + + for one_image in image: + one_image = _preprocess_adapter_image(one_image, height, width) + one_image = one_image.to(device=device, dtype=self.adapter.dtype) + adapter_input.append(one_image) + else: + adapter_input = _preprocess_adapter_image(image, height, width) + adapter_input = adapter_input.to(device=device, dtype=self.adapter.dtype) + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + ) + + self._guidance_scale = guidance_scale + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3.1 Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + clip_skip=clip_skip, + ) + + # 3.2 Encode ip_adapter_image + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.1 Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare added time ids & embeddings & adapter features + if isinstance(self.adapter, MultiAdapter): + adapter_state = self.adapter(adapter_input, adapter_conditioning_scale) + for k, v in enumerate(adapter_state): + adapter_state[k] = v + else: + adapter_state = self.adapter(adapter_input) + for k, v in enumerate(adapter_state): + adapter_state[k] = v * adapter_conditioning_scale + if num_images_per_prompt > 1: + for k, v in enumerate(adapter_state): + adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1) + if self.do_classifier_free_guidance: + for k, v in enumerate(adapter_state): + adapter_state[k] = torch.cat([v] * 2, dim=0) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + # Apply denoising_end + if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + # predict the noise residual + if i < int(num_inference_steps * adapter_conditioning_factor): + down_intrablock_additional_residuals = [state.clone() for state in adapter_state] + else: + down_intrablock_additional_residuals = None + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=cross_attention_kwargs, + down_intrablock_additional_residuals=down_intrablock_additional_residuals, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8d8fdb92769bb3dcb2dd7696115b16f197062262 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/__init__.py @@ -0,0 +1,54 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_output"] = ["TextToVideoSDPipelineOutput"] + _import_structure["pipeline_text_to_video_synth"] = ["TextToVideoSDPipeline"] + _import_structure["pipeline_text_to_video_synth_img2img"] = ["VideoToVideoSDPipeline"] + _import_structure["pipeline_text_to_video_zero"] = ["TextToVideoZeroPipeline"] + _import_structure["pipeline_text_to_video_zero_sdxl"] = ["TextToVideoZeroSDXLPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_output import TextToVideoSDPipelineOutput + from .pipeline_text_to_video_synth import TextToVideoSDPipeline + from .pipeline_text_to_video_synth_img2img import VideoToVideoSDPipeline + from .pipeline_text_to_video_zero import TextToVideoZeroPipeline + from .pipeline_text_to_video_zero_sdxl import TextToVideoZeroSDXLPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ed765d71efadb918914f9b4980909af4db74fa4 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_output.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_output.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4dd5d58ab1c45c8b8d7b1c03b61fdedf06fc1c26 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_output.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fcb3b5e57f835e1f532a94c896f0690da12b34c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth_img2img.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth_img2img.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d512ed90a7064c9fd42c703f76b6e32bc804efb5 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_synth_img2img.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_zero.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_zero.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba8d8446093184def98dd7cd254839ee9d63fdcb Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_zero.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_zero_sdxl.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_zero_sdxl.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd439904032478a3a2f3ce78b5ac1ac97f62ffbc Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/__pycache__/pipeline_text_to_video_zero_sdxl.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..c94c5d2d144a5c3c26fe247a11e28bfacc7a4832 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/pipeline_output.py @@ -0,0 +1,25 @@ +from dataclasses import dataclass + +import numpy as np +import PIL +import torch + +from ...utils import ( + BaseOutput, +) + + +@dataclass +class TextToVideoSDPipelineOutput(BaseOutput): + """ + Output class for text-to-video pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised + PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)` + """ + + frames: torch.Tensor | np.ndarray | list[list[PIL.Image.Image]] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py new file mode 100644 index 0000000000000000000000000000000000000000..0ca64d33acda4d96edec3c20f78a7434f729f8d2 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -0,0 +1,659 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet3DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin +from . import TextToVideoSDPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import TextToVideoSDPipeline + >>> from diffusers.utils import export_to_video + + >>> pipe = TextToVideoSDPipeline.from_pretrained( + ... "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16" + ... ) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "Spiderman is surfing" + >>> video_frames = pipe(prompt).frames[0] + >>> video_path = export_to_video(video_frames) + >>> video_path + ``` +""" + + +class TextToVideoSDPipeline( + DeprecatedPipelineMixin, + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, +): + _last_supported_version = "0.33.1" + r""" + Pipeline for text-to-video generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet3DConditionModel`]): + A [`UNet3DConditionModel`] to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet3DConditionModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + image = self.vae.decode(latents).sample + video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_frames: int = 16, + num_inference_steps: int = 50, + guidance_scale: float = 9.0, + negative_prompt: str | list[str] | None = None, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. Choose between `torch.Tensor` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + + Returns: + [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_images_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # reshape latents + bsz, channel, frames, width, height = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) + noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # reshape latents back + latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # 8. Post processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # 9. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return TextToVideoSDPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..6908f51eb21b14adae8e85775389f0d950aa40c8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -0,0 +1,716 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet3DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin +from . import TextToVideoSDPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler + >>> from diffusers.utils import export_to_video + + >>> pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16) + >>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + >>> pipe.to("cuda") + + >>> prompt = "spiderman running in the desert" + >>> video_frames = pipe(prompt, num_inference_steps=40, height=320, width=576, num_frames=24).frames[0] + >>> # safe low-res video + >>> video_path = export_to_video(video_frames, output_video_path="./video_576_spiderman.mp4") + + >>> # let's offload the text-to-image model + >>> pipe.to("cpu") + + >>> # and load the image-to-image model + >>> pipe = DiffusionPipeline.from_pretrained( + ... "cerspense/zeroscope_v2_XL", torch_dtype=torch.float16, revision="refs/pr/15" + ... ) + >>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + >>> pipe.enable_model_cpu_offload() + + >>> # The VAE consumes A LOT of memory, let's make sure we run it in sliced mode + >>> pipe.vae.enable_slicing() + + >>> # now let's upscale it + >>> video = [Image.fromarray(frame).resize((1024, 576)) for frame in video_frames] + + >>> # and denoise it + >>> video_frames = pipe(prompt, video=video, strength=0.6).frames[0] + >>> video_path = export_to_video(video_frames, output_video_path="./video_1024_spiderman.mp4") + >>> video_path + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class VideoToVideoSDPipeline( + DeprecatedPipelineMixin, + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, +): + _last_supported_version = "0.33.1" + r""" + Pipeline for text-guided video-to-video generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet3DConditionModel`]): + A [`UNet3DConditionModel`] to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet3DConditionModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + image = self.vae.decode(latents).sample + video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, video, timestep, batch_size, dtype, device, generator=None): + video = video.to(device=device, dtype=dtype) + + # change from (b, c, f, h, w) -> (b * f, c, w, h) + bsz, channel, frames, width, height = video.shape + video = video.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) + + if video.shape[1] == 4: + init_latents = video + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + elif isinstance(generator, list): + init_latents = [ + retrieve_latents(self.vae.encode(video[i : i + 1]), generator=generator[i]) + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = retrieve_latents(self.vae.encode(video), generator=generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `video` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + latents = latents[None, :].reshape((bsz, frames, latents.shape[1]) + latents.shape[2:]).permute(0, 2, 1, 3, 4) + + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + video: list[np.ndarray] | torch.Tensor = None, + strength: float = 0.6, + num_inference_steps: int = 50, + guidance_scale: float = 15.0, + negative_prompt: str | list[str] | None = None, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + clip_skip: int | None = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + video (`list[np.ndarray]` or `torch.Tensor`): + `video` frames or tensor representing a video batch to be used as the starting point for the process. + Can also accept video latents as `image`, if passing latents directly, it will not be encoded again. + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `video`. Must be between 0 and 1. `video` is used as a + starting point, adding more noise to it the larger the `strength`. The number of denoising steps + depends on the amount of noise initially added. When `strength` is 1, added noise is maximum and the + denoising process runs for the full number of iterations specified in `num_inference_steps`. A value of + 1 essentially ignores `video`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in video generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. Choose between `torch.Tensor` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + Examples: + + Returns: + [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is + returned, otherwise a `tuple` is returned where the first element is a list with the generated frames. + """ + # 0. Default height and width to unet + num_images_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Preprocess video + video = self.video_processor.preprocess_video(video) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents = self.prepare_latents(video, latent_timestep, batch_size, prompt_embeds.dtype, device, generator) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # reshape latents + bsz, channel, frames, width, height = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) + noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # reshape latents back + latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + + # 9. Post processing + if output_type == "latent": + video = latents + else: + video_tensor = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) + + # 10. Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return TextToVideoSDPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py new file mode 100644 index 0000000000000000000000000000000000000000..66defb2f37454eb0198d96ea7c2f9df2e12eb82f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py @@ -0,0 +1,1004 @@ +import copy +import inspect +from dataclasses import dataclass +from typing import Callable + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from torch.nn.functional import grid_sample +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + BaseOutput, + is_torch_xla_available, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import empty_device_cache, randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin +from ..stable_diffusion import StableDiffusionSafetyChecker + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def rearrange_0(tensor, f): + F, C, H, W = tensor.size() + tensor = torch.permute(torch.reshape(tensor, (F // f, f, C, H, W)), (0, 2, 1, 3, 4)) + return tensor + + +def rearrange_1(tensor): + B, C, F, H, W = tensor.size() + return torch.reshape(torch.permute(tensor, (0, 2, 1, 3, 4)), (B * F, C, H, W)) + + +def rearrange_3(tensor, f): + F, D, C = tensor.size() + return torch.reshape(tensor, (F // f, f, D, C)) + + +def rearrange_4(tensor): + B, F, D, C = tensor.size() + return torch.reshape(tensor, (B * F, D, C)) + + +class CrossFrameAttnProcessor: + """ + Cross frame attention processor. Each frame attends the first frame. + + Args: + batch_size: The number that represents actual batch size, other than the frames. + For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to + 2, due to classifier-free guidance. + """ + + def __init__(self, batch_size=2): + self.batch_size = batch_size + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + query = attn.to_q(hidden_states) + + is_cross_attention = encoder_hidden_states is not None + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # Cross Frame Attention + if not is_cross_attention: + video_length = key.size()[0] // self.batch_size + first_frame_index = [0] * video_length + + # rearrange keys to have batch and frames in the 1st and 2nd dims respectively + key = rearrange_3(key, video_length) + key = key[:, first_frame_index] + # rearrange values to have batch and frames in the 1st and 2nd dims respectively + value = rearrange_3(value, video_length) + value = value[:, first_frame_index] + + # rearrange back to original shape + key = rearrange_4(key) + value = rearrange_4(value) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class CrossFrameAttnProcessor2_0: + """ + Cross frame attention processor with scaled_dot_product attention of Pytorch 2.0. + + Args: + batch_size: The number that represents actual batch size, other than the frames. + For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to + 2, due to classifier-free guidance. + """ + + def __init__(self, batch_size=2): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + self.batch_size = batch_size + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + inner_dim = hidden_states.shape[-1] + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + + is_cross_attention = encoder_hidden_states is not None + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # Cross Frame Attention + if not is_cross_attention: + video_length = max(1, key.size()[0] // self.batch_size) + first_frame_index = [0] * video_length + + # rearrange keys to have batch and frames in the 1st and 2nd dims respectively + key = rearrange_3(key, video_length) + key = key[:, first_frame_index] + # rearrange values to have batch and frames in the 1st and 2nd dims respectively + value = rearrange_3(value, video_length) + value = value[:, first_frame_index] + + # rearrange back to original shape + key = rearrange_4(key) + value = rearrange_4(value) + + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +@dataclass +class TextToVideoPipelineOutput(BaseOutput): + r""" + Output class for zero-shot text-to-video pipeline. + + Args: + images (`[list[PIL.Image.Image]`, `np.ndarray`]): + list of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + nsfw_content_detected (`[list[bool]]`): + list indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or + `None` if safety checking could not be performed. + """ + + images: list[PIL.Image.Image] | np.ndarray + nsfw_content_detected: list[bool] | None + + +def coords_grid(batch, ht, wd, device): + # Adapted from https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py + coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +def warp_single_latent(latent, reference_flow): + """ + Warp latent of a single frame with given flow + + Args: + latent: latent code of a single frame + reference_flow: flow which to warp the latent with + + Returns: + warped: warped latent + """ + _, _, H, W = reference_flow.size() + _, _, h, w = latent.size() + coords0 = coords_grid(1, H, W, device=latent.device).to(latent.dtype) + + coords_t0 = coords0 + reference_flow + coords_t0[:, 0] /= W + coords_t0[:, 1] /= H + + coords_t0 = coords_t0 * 2.0 - 1.0 + coords_t0 = F.interpolate(coords_t0, size=(h, w), mode="bilinear") + coords_t0 = torch.permute(coords_t0, (0, 2, 3, 1)) + + warped = grid_sample(latent, coords_t0, mode="nearest", padding_mode="reflection") + return warped + + +def create_motion_field(motion_field_strength_x, motion_field_strength_y, frame_ids, device, dtype): + """ + Create translation motion field + + Args: + motion_field_strength_x: motion strength along x-axis + motion_field_strength_y: motion strength along y-axis + frame_ids: indexes of the frames the latents of which are being processed. + This is needed when we perform chunk-by-chunk inference + device: device + dtype: dtype + + Returns: + + """ + seq_length = len(frame_ids) + reference_flow = torch.zeros((seq_length, 2, 512, 512), device=device, dtype=dtype) + for fr_idx in range(seq_length): + reference_flow[fr_idx, 0, :, :] = motion_field_strength_x * (frame_ids[fr_idx]) + reference_flow[fr_idx, 1, :, :] = motion_field_strength_y * (frame_ids[fr_idx]) + return reference_flow + + +def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_strength_y, frame_ids, latents): + """ + Creates translation motion and warps the latents accordingly + + Args: + motion_field_strength_x: motion strength along x-axis + motion_field_strength_y: motion strength along y-axis + frame_ids: indexes of the frames the latents of which are being processed. + This is needed when we perform chunk-by-chunk inference + latents: latent codes of frames + + Returns: + warped_latents: warped latents + """ + motion_field = create_motion_field( + motion_field_strength_x=motion_field_strength_x, + motion_field_strength_y=motion_field_strength_y, + frame_ids=frame_ids, + device=latents.device, + dtype=latents.dtype, + ) + warped_latents = latents.clone().detach() + for i in range(len(warped_latents)): + warped_latents[i] = warp_single_latent(latents[i][None], motion_field[i][None]) + return warped_latents + + +class TextToVideoZeroPipeline( + DeprecatedPipelineMixin, + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + FromSingleFileMixin, +): + _last_supported_version = "0.33.1" + r""" + Pipeline for zero-shot text-to-video generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (`CLIPTokenizer`): + A [`~transformers.CLIPTokenizer`] to tokenize text. + unet ([`UNet2DConditionModel`]): + A [`UNet3DConditionModel`] to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for + more details about a model's potential harms. + feature_extractor ([`CLIPImageProcessor`]): + A [`CLIPImageProcessor`] to extract features from generated images; used as inputs to the `safety_checker`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def forward_loop(self, x_t0, t0, t1, generator): + """ + Perform DDPM forward process from time t0 to t1. This is the same as adding noise with corresponding variance. + + Args: + x_t0: + Latent code at time t0. + t0: + Timestep at t0. + t1: + Timestamp at t1. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + + Returns: + x_t1: + Forward process applied to x_t0 from time t0 to t1. + """ + eps = randn_tensor(x_t0.size(), generator=generator, dtype=x_t0.dtype, device=x_t0.device) + alpha_vec = torch.prod(self.scheduler.alphas[t0:t1]) + x_t1 = torch.sqrt(alpha_vec) * x_t0 + torch.sqrt(1 - alpha_vec) * eps + return x_t1 + + def backward_loop( + self, + latents, + timesteps, + prompt_embeds, + guidance_scale, + callback, + callback_steps, + num_warmup_steps, + extra_step_kwargs, + cross_attention_kwargs=None, + ): + """ + Perform backward process given list of time steps. + + Args: + latents: + Latents at time timesteps[0]. + timesteps: + Time steps along which to perform backward process. + prompt_embeds: + Pre-generated text embeddings. + guidance_scale: + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + extra_step_kwargs: + Extra_step_kwargs. + cross_attention_kwargs: + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + num_warmup_steps: + number of warmup steps. + + Returns: + latents: + Latents of backward process output at time timesteps[-1]. + """ + do_classifier_free_guidance = guidance_scale > 1.0 + num_steps = (len(timesteps) - num_warmup_steps) // self.scheduler.order + with self.progress_bar(total=num_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + return latents.clone().detach() + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str], + video_length: int | None = 8, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_videos_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + motion_field_strength_x: float = 12, + motion_field_strength_y: float = 12, + output_type: str | None = "tensor", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int | None = 1, + t0: int = 44, + t1: int = 47, + frame_ids: list[int] | None = None, + ): + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + video_length (`int`, *optional*, defaults to 8): + The number of generated video frames. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in video generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. Choose between `"latent"` and `"np"`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a + [`~pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoPipelineOutput`] instead of + a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + motion_field_strength_x (`float`, *optional*, defaults to 12): + Strength of motion in generated video along x-axis. See the + [paper](https://huggingface.co/papers/2303.13439), Sect. 3.3.1. + motion_field_strength_y (`float`, *optional*, defaults to 12): + Strength of motion in generated video along y-axis. See the + [paper](https://huggingface.co/papers/2303.13439), Sect. 3.3.1. + t0 (`int`, *optional*, defaults to 44): + Timestep t0. Should be in the range [0, num_inference_steps - 1]. See the + [paper](https://huggingface.co/papers/2303.13439), Sect. 3.3.1. + t1 (`int`, *optional*, defaults to 47): + Timestep t0. Should be in the range [t0 + 1, num_inference_steps - 1]. See the + [paper](https://huggingface.co/papers/2303.13439), Sect. 3.3.1. + frame_ids (`list[int]`, *optional*): + Indexes of the frames that are being generated. This is used when generating longer videos + chunk-by-chunk. + + Returns: + [`~pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoPipelineOutput`]: + The output contains a `ndarray` of the generated video, when `output_type` != `"latent"`, otherwise a + latent code of generated videos and a list of `bool`s indicating whether the corresponding generated + video contains "not-safe-for-work" (nsfw) content.. + """ + assert video_length > 0 + if frame_ids is None: + frame_ids = list(range(video_length)) + assert len(frame_ids) == video_length + + assert num_videos_per_prompt == 1 + + # set the processor + original_attn_proc = self.unet.attn_processors + processor = ( + CrossFrameAttnProcessor2_0(batch_size=2) + if hasattr(F, "scaled_dot_product_attention") + else CrossFrameAttnProcessor(batch_size=2) + ) + self.unet.set_attn_processor(processor) + + if isinstance(prompt, str): + prompt = [prompt] + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + # Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # Encode input prompt + prompt_embeds_tuple = self.encode_prompt( + prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt + ) + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + # Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + # Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # Perform the first backward process up to time T_1 + x_1_t1 = self.backward_loop( + timesteps=timesteps[: -t1 - 1], + prompt_embeds=prompt_embeds, + latents=latents, + guidance_scale=guidance_scale, + callback=callback, + callback_steps=callback_steps, + extra_step_kwargs=extra_step_kwargs, + num_warmup_steps=num_warmup_steps, + ) + scheduler_copy = copy.deepcopy(self.scheduler) + + # Perform the second backward process up to time T_0 + x_1_t0 = self.backward_loop( + timesteps=timesteps[-t1 - 1 : -t0 - 1], + prompt_embeds=prompt_embeds, + latents=x_1_t1, + guidance_scale=guidance_scale, + callback=callback, + callback_steps=callback_steps, + extra_step_kwargs=extra_step_kwargs, + num_warmup_steps=0, + ) + + # Propagate first frame latents at time T_0 to remaining frames + x_2k_t0 = x_1_t0.repeat(video_length - 1, 1, 1, 1) + + # Add motion in latents at time T_0 + x_2k_t0 = create_motion_field_and_warp_latents( + motion_field_strength_x=motion_field_strength_x, + motion_field_strength_y=motion_field_strength_y, + latents=x_2k_t0, + frame_ids=frame_ids[1:], + ) + + # Perform forward process up to time T_1 + x_2k_t1 = self.forward_loop( + x_t0=x_2k_t0, + t0=timesteps[-t0 - 1].item(), + t1=timesteps[-t1 - 1].item(), + generator=generator, + ) + + # Perform backward process from time T_1 to 0 + x_1k_t1 = torch.cat([x_1_t1, x_2k_t1]) + b, l, d = prompt_embeds.size() + prompt_embeds = prompt_embeds[:, None].repeat(1, video_length, 1, 1).reshape(b * video_length, l, d) + + self.scheduler = scheduler_copy + x_1k_0 = self.backward_loop( + timesteps=timesteps[-t1 - 1 :], + prompt_embeds=prompt_embeds, + latents=x_1k_t1, + guidance_scale=guidance_scale, + callback=callback, + callback_steps=callback_steps, + extra_step_kwargs=extra_step_kwargs, + num_warmup_steps=0, + ) + latents = x_1k_0 + + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + empty_device_cache() + + if output_type == "latent": + image = latents + has_nsfw_concept = None + else: + image = self.decode_latents(latents) + # Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload all models + self.maybe_free_model_hooks() + # make sure to set the original attention processors back + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image, has_nsfw_concept) + + return TextToVideoPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py new file mode 100644 index 0000000000000000000000000000000000000000..a3286cd940fd4d339c1d2dfd257964a9a7564427 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py @@ -0,0 +1,1340 @@ +import copy +import inspect +from dataclasses import dataclass +from typing import Any, Callable + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from torch.nn.functional import grid_sample +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...image_processor import VaeImageProcessor +from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + BaseOutput, + deprecate, + is_invisible_watermark_available, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + + +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.rearrange_0 +def rearrange_0(tensor, f): + F, C, H, W = tensor.size() + tensor = torch.permute(torch.reshape(tensor, (F // f, f, C, H, W)), (0, 2, 1, 3, 4)) + return tensor + + +# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.rearrange_1 +def rearrange_1(tensor): + B, C, F, H, W = tensor.size() + return torch.reshape(torch.permute(tensor, (0, 2, 1, 3, 4)), (B * F, C, H, W)) + + +# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.rearrange_3 +def rearrange_3(tensor, f): + F, D, C = tensor.size() + return torch.reshape(tensor, (F // f, f, D, C)) + + +# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.rearrange_4 +def rearrange_4(tensor): + B, F, D, C = tensor.size() + return torch.reshape(tensor, (B * F, D, C)) + + +# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor +class CrossFrameAttnProcessor: + """ + Cross frame attention processor. Each frame attends the first frame. + + Args: + batch_size: The number that represents actual batch size, other than the frames. + For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to + 2, due to classifier-free guidance. + """ + + def __init__(self, batch_size=2): + self.batch_size = batch_size + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + query = attn.to_q(hidden_states) + + is_cross_attention = encoder_hidden_states is not None + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # Cross Frame Attention + if not is_cross_attention: + video_length = key.size()[0] // self.batch_size + first_frame_index = [0] * video_length + + # rearrange keys to have batch and frames in the 1st and 2nd dims respectively + key = rearrange_3(key, video_length) + key = key[:, first_frame_index] + # rearrange values to have batch and frames in the 1st and 2nd dims respectively + value = rearrange_3(value, video_length) + value = value[:, first_frame_index] + + # rearrange back to original shape + key = rearrange_4(key) + value = rearrange_4(value) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor2_0 +class CrossFrameAttnProcessor2_0: + """ + Cross frame attention processor with scaled_dot_product attention of Pytorch 2.0. + + Args: + batch_size: The number that represents actual batch size, other than the frames. + For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to + 2, due to classifier-free guidance. + """ + + def __init__(self, batch_size=2): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + self.batch_size = batch_size + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + inner_dim = hidden_states.shape[-1] + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + + is_cross_attention = encoder_hidden_states is not None + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # Cross Frame Attention + if not is_cross_attention: + video_length = max(1, key.size()[0] // self.batch_size) + first_frame_index = [0] * video_length + + # rearrange keys to have batch and frames in the 1st and 2nd dims respectively + key = rearrange_3(key, video_length) + key = key[:, first_frame_index] + # rearrange values to have batch and frames in the 1st and 2nd dims respectively + value = rearrange_3(value, video_length) + value = value[:, first_frame_index] + + # rearrange back to original shape + key = rearrange_4(key) + value = rearrange_4(value) + + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +@dataclass +class TextToVideoSDXLPipelineOutput(BaseOutput): + """ + Output class for zero-shot text-to-video pipeline. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image] | np.ndarray + + +# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.coords_grid +def coords_grid(batch, ht, wd, device): + # Adapted from https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py + coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.warp_single_latent +def warp_single_latent(latent, reference_flow): + """ + Warp latent of a single frame with given flow + + Args: + latent: latent code of a single frame + reference_flow: flow which to warp the latent with + + Returns: + warped: warped latent + """ + _, _, H, W = reference_flow.size() + _, _, h, w = latent.size() + coords0 = coords_grid(1, H, W, device=latent.device).to(latent.dtype) + + coords_t0 = coords0 + reference_flow + coords_t0[:, 0] /= W + coords_t0[:, 1] /= H + + coords_t0 = coords_t0 * 2.0 - 1.0 + coords_t0 = F.interpolate(coords_t0, size=(h, w), mode="bilinear") + coords_t0 = torch.permute(coords_t0, (0, 2, 3, 1)) + + warped = grid_sample(latent, coords_t0, mode="nearest", padding_mode="reflection") + return warped + + +# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.create_motion_field +def create_motion_field(motion_field_strength_x, motion_field_strength_y, frame_ids, device, dtype): + """ + Create translation motion field + + Args: + motion_field_strength_x: motion strength along x-axis + motion_field_strength_y: motion strength along y-axis + frame_ids: indexes of the frames the latents of which are being processed. + This is needed when we perform chunk-by-chunk inference + device: device + dtype: dtype + + Returns: + + """ + seq_length = len(frame_ids) + reference_flow = torch.zeros((seq_length, 2, 512, 512), device=device, dtype=dtype) + for fr_idx in range(seq_length): + reference_flow[fr_idx, 0, :, :] = motion_field_strength_x * (frame_ids[fr_idx]) + reference_flow[fr_idx, 1, :, :] = motion_field_strength_y * (frame_ids[fr_idx]) + return reference_flow + + +# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.create_motion_field_and_warp_latents +def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_strength_y, frame_ids, latents): + """ + Creates translation motion and warps the latents accordingly + + Args: + motion_field_strength_x: motion strength along x-axis + motion_field_strength_y: motion strength along y-axis + frame_ids: indexes of the frames the latents of which are being processed. + This is needed when we perform chunk-by-chunk inference + latents: latent codes of frames + + Returns: + warped_latents: warped latents + """ + motion_field = create_motion_field( + motion_field_strength_x=motion_field_strength_x, + motion_field_strength_y=motion_field_strength_y, + frame_ids=frame_ids, + device=latents.device, + dtype=latents.dtype, + ) + warped_latents = latents.clone().detach() + for i in range(len(warped_latents)): + warped_latents[i] = warp_single_latent(latents[i][None], motion_field[i][None]) + return warped_latents + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class TextToVideoZeroSDXLPipeline( + DeprecatedPipelineMixin, + DiffusionPipeline, + StableDiffusionMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +): + _last_supported_version = "0.33.1" + r""" + Pipeline for zero-shot text-to-video generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: bool | None = None, + ): + super().__init__() + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae + def upcast_vae(self): + deprecate( + "upcast_vae", + "1.0.0", + "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.", + ) + self.vae.to(dtype=torch.float32) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: str | None = None, + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | None = None, + negative_prompt_2: str | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: list[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoZeroPipeline.forward_loop + def forward_loop(self, x_t0, t0, t1, generator): + """ + Perform DDPM forward process from time t0 to t1. This is the same as adding noise with corresponding variance. + + Args: + x_t0: + Latent code at time t0. + t0: + Timestep at t0. + t1: + Timestamp at t1. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + + Returns: + x_t1: + Forward process applied to x_t0 from time t0 to t1. + """ + eps = randn_tensor(x_t0.size(), generator=generator, dtype=x_t0.dtype, device=x_t0.device) + alpha_vec = torch.prod(self.scheduler.alphas[t0:t1]) + x_t1 = torch.sqrt(alpha_vec) * x_t0 + torch.sqrt(1 - alpha_vec) * eps + return x_t1 + + def backward_loop( + self, + latents, + timesteps, + prompt_embeds, + guidance_scale, + callback, + callback_steps, + num_warmup_steps, + extra_step_kwargs, + add_text_embeds, + add_time_ids, + cross_attention_kwargs=None, + guidance_rescale: float = 0.0, + ): + """ + Perform backward process given list of time steps + + Args: + latents: + Latents at time timesteps[0]. + timesteps: + Time steps along which to perform backward process. + prompt_embeds: + Pre-generated text embeddings. + guidance_scale: + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + extra_step_kwargs: + Extra_step_kwargs. + cross_attention_kwargs: + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + num_warmup_steps: + number of warmup steps. + + Returns: + latents: latents of backward process output at time timesteps[-1] + """ + + do_classifier_free_guidance = guidance_scale > 1.0 + num_steps = (len(timesteps) - num_warmup_steps) // self.scheduler.order + + with self.progress_bar(total=num_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + return latents.clone().detach() + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str], + prompt_2: str | list[str] | None = None, + video_length: int | None = 8, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + denoising_end: float | None = None, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + num_videos_per_prompt: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + frame_ids: list[int] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + pooled_prompt_embeds: torch.Tensor | None = None, + negative_pooled_prompt_embeds: torch.Tensor | None = None, + latents: torch.Tensor | None = None, + motion_field_strength_x: float = 12, + motion_field_strength_y: float = 12, + output_type: str | None = "tensor", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + cross_attention_kwargs: dict[str, Any] | None = None, + guidance_rescale: float = 0.0, + original_size: tuple[int, int] | None = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: tuple[int, int] | None = None, + t0: int = 44, + t1: int = 47, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + video_length (`int`, *optional*, defaults to 8): + The number of generated video frames. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + frame_ids (`list[int]`, *optional*): + Indexes of the frames that are being generated. This is used when generating longer videos + chunk-by-chunk. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + motion_field_strength_x (`float`, *optional*, defaults to 12): + Strength of motion in generated video along x-axis. See the + [paper](https://huggingface.co/papers/2303.13439), Sect. 3.3.1. + motion_field_strength_y (`float`, *optional*, defaults to 12): + Strength of motion in generated video along y-axis. See the + [paper](https://huggingface.co/papers/2303.13439), Sect. 3.3.1. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + t0 (`int`, *optional*, defaults to 44): + Timestep t0. Should be in the range [0, num_inference_steps - 1]. See the + [paper](https://huggingface.co/papers/2303.13439), Sect. 3.3.1. + t1 (`int`, *optional*, defaults to 47): + Timestep t0. Should be in the range [t0 + 1, num_inference_steps - 1]. See the + [paper](https://huggingface.co/papers/2303.13439), Sect. 3.3.1. + + Returns: + [`~pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoSDXLPipelineOutput`] or + `tuple`: [`~pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoSDXLPipelineOutput`] + if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + assert video_length > 0 + if frame_ids is None: + frame_ids = list(range(video_length)) + assert len(frame_ids) == video_length + + assert num_videos_per_prompt == 1 + + # set the processor + original_attn_proc = self.unet.attn_processors + processor = ( + CrossFrameAttnProcessor2_0(batch_size=2) + if hasattr(F, "scaled_dot_product_attention") + else CrossFrameAttnProcessor(batch_size=2) + ) + self.unet.set_attn_processor(processor) + + if isinstance(prompt, str): + prompt = [prompt] + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + batch_size = ( + 1 if isinstance(prompt, str) else len(prompt) if isinstance(prompt, list) else prompt_embeds.shape[0] + ) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_videos_per_prompt, 1) + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # Perform the first backward process up to time T_1 + x_1_t1 = self.backward_loop( + timesteps=timesteps[: -t1 - 1], + prompt_embeds=prompt_embeds, + latents=latents, + guidance_scale=guidance_scale, + callback=callback, + callback_steps=callback_steps, + extra_step_kwargs=extra_step_kwargs, + num_warmup_steps=num_warmup_steps, + add_text_embeds=add_text_embeds, + add_time_ids=add_time_ids, + ) + + scheduler_copy = copy.deepcopy(self.scheduler) + + # Perform the second backward process up to time T_0 + x_1_t0 = self.backward_loop( + timesteps=timesteps[-t1 - 1 : -t0 - 1], + prompt_embeds=prompt_embeds, + latents=x_1_t1, + guidance_scale=guidance_scale, + callback=callback, + callback_steps=callback_steps, + extra_step_kwargs=extra_step_kwargs, + num_warmup_steps=0, + add_text_embeds=add_text_embeds, + add_time_ids=add_time_ids, + ) + + # Propagate first frame latents at time T_0 to remaining frames + x_2k_t0 = x_1_t0.repeat(video_length - 1, 1, 1, 1) + + # Add motion in latents at time T_0 + x_2k_t0 = create_motion_field_and_warp_latents( + motion_field_strength_x=motion_field_strength_x, + motion_field_strength_y=motion_field_strength_y, + latents=x_2k_t0, + frame_ids=frame_ids[1:], + ) + + # Perform forward process up to time T_1 + x_2k_t1 = self.forward_loop( + x_t0=x_2k_t0, + t0=timesteps[-t0 - 1].to(torch.long), + t1=timesteps[-t1 - 1].to(torch.long), + generator=generator, + ) + + # Perform backward process from time T_1 to 0 + latents = torch.cat([x_1_t1, x_2k_t1]) + + self.scheduler = scheduler_copy + timesteps = timesteps[-t1 - 1 :] + + b, l, d = prompt_embeds.size() + prompt_embeds = prompt_embeds[:, None].repeat(1, video_length, 1, 1).reshape(b * video_length, l, d) + + b, k = add_text_embeds.size() + add_text_embeds = add_text_embeds[:, None].repeat(1, video_length, 1).reshape(b * video_length, k) + + b, k = add_time_ids.size() + add_time_ids = add_time_ids[:, None].repeat(1, video_length, 1).reshape(b * video_length, k) + + # 7.1 Apply denoising_end + if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + x_1k_0 = self.backward_loop( + timesteps=timesteps, + prompt_embeds=prompt_embeds, + latents=latents, + guidance_scale=guidance_scale, + callback=callback, + callback_steps=callback_steps, + extra_step_kwargs=extra_step_kwargs, + num_warmup_steps=0, + add_text_embeds=add_text_embeds, + add_time_ids=add_time_ids, + ) + + latents = x_1k_0 + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + return TextToVideoSDXLPipelineOutput(images=image) + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + self.maybe_free_model_hooks() + # make sure to set the original attention processors back + self.unet.set_attn_processor(original_attn_proc) + + if not return_dict: + return (image,) + + return TextToVideoSDXLPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/unclip/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/unclip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c89e899463beede59b8ccf02688f6168b8ee3d77 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/unclip/__init__.py @@ -0,0 +1,52 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import UnCLIPImageVariationPipeline, UnCLIPPipeline + + _dummy_objects.update( + {"UnCLIPImageVariationPipeline": UnCLIPImageVariationPipeline, "UnCLIPPipeline": UnCLIPPipeline} + ) +else: + _import_structure["pipeline_unclip"] = ["UnCLIPPipeline"] + _import_structure["pipeline_unclip_image_variation"] = ["UnCLIPImageVariationPipeline"] + _import_structure["text_proj"] = ["UnCLIPTextProjModel"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_unclip import UnCLIPPipeline + from .pipeline_unclip_image_variation import UnCLIPImageVariationPipeline + from .text_proj import UnCLIPTextProjModel + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/unclip/pipeline_unclip.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/unclip/pipeline_unclip.py new file mode 100644 index 0000000000000000000000000000000000000000..430f1a1e52651b9f98174a081567f463235fcee3 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/unclip/pipeline_unclip.py @@ -0,0 +1,503 @@ +# Copyright 2025 Kakao Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import torch +from torch.nn import functional as F +from transformers import CLIPTextModelWithProjection, CLIPTokenizer +from transformers.models.clip.modeling_clip import CLIPTextModelOutput + +from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel +from ...schedulers import UnCLIPScheduler +from ...utils import is_torch_xla_available, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput +from .text_proj import UnCLIPTextProjModel + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class UnCLIPPipeline(DeprecatedPipelineMixin, DiffusionPipeline): + """ + Pipeline for text-to-image generation using unCLIP. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`~transformers.CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + prior ([`PriorTransformer`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + text_proj ([`UnCLIPTextProjModel`]): + Utility class to prepare and combine the embeddings before they are passed to the decoder. + decoder ([`UNet2DConditionModel`]): + The decoder to invert the image embedding into an image. + super_res_first ([`UNet2DModel`]): + Super resolution UNet. Used in all but the last step of the super resolution diffusion process. + super_res_last ([`UNet2DModel`]): + Super resolution UNet. Used in the last step of the super resolution diffusion process. + prior_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the prior denoising process (a modified [`DDPMScheduler`]). + decoder_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the decoder denoising process (a modified [`DDPMScheduler`]). + super_res_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the super resolution denoising process (a modified [`DDPMScheduler`]). + + """ + + _last_supported_version = "0.33.1" + _exclude_from_cpu_offload = ["prior"] + + prior: PriorTransformer + decoder: UNet2DConditionModel + text_proj: UnCLIPTextProjModel + text_encoder: CLIPTextModelWithProjection + tokenizer: CLIPTokenizer + super_res_first: UNet2DModel + super_res_last: UNet2DModel + + prior_scheduler: UnCLIPScheduler + decoder_scheduler: UnCLIPScheduler + super_res_scheduler: UnCLIPScheduler + + model_cpu_offload_seq = "text_encoder->text_proj->decoder->super_res_first->super_res_last" + + def __init__( + self, + prior: PriorTransformer, + decoder: UNet2DConditionModel, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_proj: UnCLIPTextProjModel, + super_res_first: UNet2DModel, + super_res_last: UNet2DModel, + prior_scheduler: UnCLIPScheduler, + decoder_scheduler: UnCLIPScheduler, + super_res_scheduler: UnCLIPScheduler, + ): + super().__init__() + + self.register_modules( + prior=prior, + decoder=decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_proj=text_proj, + super_res_first=super_res_first, + super_res_last=super_res_last, + prior_scheduler=prior_scheduler, + decoder_scheduler=decoder_scheduler, + super_res_scheduler=super_res_scheduler, + ) + + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + text_model_output: CLIPTextModelOutput | tuple | None = None, + text_attention_mask: torch.Tensor | None = None, + ): + if text_model_output is None: + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + prompt_embeds = text_encoder_output.text_embeds + text_enc_hid_states = text_encoder_output.last_hidden_state + + else: + batch_size = text_model_output[0].shape[0] + prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1] + text_mask = text_attention_mask + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds + uncond_text_enc_hid_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_enc_hid_states.shape[1] + uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1) + uncond_text_enc_hid_states = uncond_text_enc_hid_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_enc_hid_states, text_mask + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str] | None = None, + num_images_per_prompt: int = 1, + prior_num_inference_steps: int = 25, + decoder_num_inference_steps: int = 25, + super_res_num_inference_steps: int = 7, + generator: torch.Generator | list[torch.Generator] | None = None, + prior_latents: torch.Tensor | None = None, + decoder_latents: torch.Tensor | None = None, + super_res_latents: torch.Tensor | None = None, + text_model_output: CLIPTextModelOutput | tuple | None = None, + text_attention_mask: torch.Tensor | None = None, + prior_guidance_scale: float = 4.0, + decoder_guidance_scale: float = 8.0, + output_type: str | None = "pil", + return_dict: bool = True, + ): + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide image generation. This can only be left undefined if `text_model_output` + and `text_attention_mask` is passed. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prior_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the prior. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + decoder_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + super_res_num_inference_steps (`int`, *optional*, defaults to 7): + The number of denoising steps for super resolution. More denoising steps usually lead to a higher + quality image at the expense of slower inference. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prior_latents (`torch.Tensor` of shape (batch size, embeddings dimension), *optional*): + Pre-generated noisy latents to be used as inputs for the prior. + decoder_latents (`torch.Tensor` of shape (batch size, channels, height, width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + super_res_latents (`torch.Tensor` of shape (batch size, channels, super res height, super res width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + decoder_guidance_scale (`float`, *optional*, defaults to 4.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + text_model_output (`CLIPTextModelOutput`, *optional*): + Pre-defined [`CLIPTextModel`] outputs that can be derived from the text encoder. Pre-defined text + outputs can be passed for tasks like text embedding interpolations. Make sure to also pass + `text_attention_mask` in this case. `prompt` can the be left `None`. + text_attention_mask (`torch.Tensor`, *optional*): + Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention + masks are necessary when passing `text_model_output`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + if prompt is not None: + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + else: + batch_size = text_model_output[0].shape[0] + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0 + + prompt_embeds, text_enc_hid_states, text_mask = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask + ) + + # prior + + self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device) + prior_timesteps_tensor = self.prior_scheduler.timesteps + + embedding_dim = self.prior.config.embedding_dim + + prior_latents = self.prepare_latents( + (batch_size, embedding_dim), + prompt_embeds.dtype, + device, + generator, + prior_latents, + self.prior_scheduler, + ) + + for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=prompt_embeds, + encoder_hidden_states=text_enc_hid_states, + attention_mask=text_mask, + ).predicted_image_embedding + + if do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + if i + 1 == prior_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = prior_timesteps_tensor[i + 1] + + prior_latents = self.prior_scheduler.step( + predicted_image_embedding, + timestep=t, + sample=prior_latents, + generator=generator, + prev_timestep=prev_timestep, + ).prev_sample + + prior_latents = self.prior.post_process_latents(prior_latents) + + image_embeddings = prior_latents + + # done prior + + # decoder + + text_enc_hid_states, additive_clip_time_embeddings = self.text_proj( + image_embeddings=image_embeddings, + prompt_embeds=prompt_embeds, + text_encoder_hidden_states=text_enc_hid_states, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + if device.type == "mps": + # HACK: MPS: There is a panic when padding bool tensors, + # so cast to int tensor for the pad and back to bool afterwards + text_mask = text_mask.type(torch.int) + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) + decoder_text_mask = decoder_text_mask.type(torch.bool) + else: + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True) + + self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) + decoder_timesteps_tensor = self.decoder_scheduler.timesteps + + num_channels_latents = self.decoder.config.in_channels + height = self.decoder.config.sample_size + width = self.decoder.config.sample_size + + decoder_latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + text_enc_hid_states.dtype, + device, + generator, + decoder_latents, + self.decoder_scheduler, + ) + + for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents + + noise_pred = self.decoder( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_enc_hid_states, + class_labels=additive_clip_time_embeddings, + attention_mask=decoder_text_mask, + ).sample + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if i + 1 == decoder_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = decoder_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + decoder_latents = self.decoder_scheduler.step( + noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + decoder_latents = decoder_latents.clamp(-1, 1) + + image_small = decoder_latents + + # done decoder + + # super res + + self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) + super_res_timesteps_tensor = self.super_res_scheduler.timesteps + + channels = self.super_res_first.config.in_channels // 2 + height = self.super_res_first.config.sample_size + width = self.super_res_first.config.sample_size + + super_res_latents = self.prepare_latents( + (batch_size, channels, height, width), + image_small.dtype, + device, + generator, + super_res_latents, + self.super_res_scheduler, + ) + + if device.type == "mps": + # MPS does not support many interpolations + image_upscaled = F.interpolate(image_small, size=[height, width]) + else: + interpolate_antialias = {} + if "antialias" in inspect.signature(F.interpolate).parameters: + interpolate_antialias["antialias"] = True + + image_upscaled = F.interpolate( + image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias + ) + + for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): + # no classifier free guidance + + if i == super_res_timesteps_tensor.shape[0] - 1: + unet = self.super_res_last + else: + unet = self.super_res_first + + latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) + + noise_pred = unet( + sample=latent_model_input, + timestep=t, + ).sample + + if i + 1 == super_res_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = super_res_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + super_res_latents = self.super_res_scheduler.step( + noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + if XLA_AVAILABLE: + xm.mark_step() + + image = super_res_latents + # done super res + + self.maybe_free_model_hooks() + + # post processing + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py new file mode 100644 index 0000000000000000000000000000000000000000..d0d8bdc447879151ea401e42062a8236618a815c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py @@ -0,0 +1,430 @@ +# Copyright 2025 Kakao Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import PIL.Image +import torch +from torch.nn import functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...models import UNet2DConditionModel, UNet2DModel +from ...schedulers import UnCLIPScheduler +from ...utils import is_torch_xla_available, logging +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput +from .text_proj import UnCLIPTextProjModel + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class UnCLIPImageVariationPipeline(DeprecatedPipelineMixin, DiffusionPipeline): + """ + Pipeline to generate image variations from an input image using UnCLIP. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`~transformers.CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `image_encoder`. + image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + text_proj ([`UnCLIPTextProjModel`]): + Utility class to prepare and combine the embeddings before they are passed to the decoder. + decoder ([`UNet2DConditionModel`]): + The decoder to invert the image embedding into an image. + super_res_first ([`UNet2DModel`]): + Super resolution UNet. Used in all but the last step of the super resolution diffusion process. + super_res_last ([`UNet2DModel`]): + Super resolution UNet. Used in the last step of the super resolution diffusion process. + decoder_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the decoder denoising process (a modified [`DDPMScheduler`]). + super_res_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the super resolution denoising process (a modified [`DDPMScheduler`]). + """ + + _last_supported_version = "0.33.1" + decoder: UNet2DConditionModel + text_proj: UnCLIPTextProjModel + text_encoder: CLIPTextModelWithProjection + tokenizer: CLIPTokenizer + feature_extractor: CLIPImageProcessor + image_encoder: CLIPVisionModelWithProjection + super_res_first: UNet2DModel + super_res_last: UNet2DModel + + decoder_scheduler: UnCLIPScheduler + super_res_scheduler: UnCLIPScheduler + model_cpu_offload_seq = "text_encoder->image_encoder->text_proj->decoder->super_res_first->super_res_last" + + def __init__( + self, + decoder: UNet2DConditionModel, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_proj: UnCLIPTextProjModel, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection, + super_res_first: UNet2DModel, + super_res_last: UNet2DModel, + decoder_scheduler: UnCLIPScheduler, + super_res_scheduler: UnCLIPScheduler, + ): + super().__init__() + + self.register_modules( + decoder=decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_proj=text_proj, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + super_res_first=super_res_first, + super_res_last=super_res_last, + decoder_scheduler=decoder_scheduler, + super_res_scheduler=super_res_scheduler, + ) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + prompt_embeds = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds + uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: torch.Tensor | None = None): + dtype = next(self.image_encoder.parameters()).dtype + + if image_embeddings is None: + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(images=image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + + image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + + return image_embeddings + + @torch.no_grad() + def __call__( + self, + image: PIL.Image.Image | list[PIL.Image.Image] | torch.Tensor | None = None, + num_images_per_prompt: int = 1, + decoder_num_inference_steps: int = 25, + super_res_num_inference_steps: int = 7, + generator: torch.Generator | None = None, + decoder_latents: torch.Tensor | None = None, + super_res_latents: torch.Tensor | None = None, + image_embeddings: torch.Tensor | None = None, + decoder_guidance_scale: float = 8.0, + output_type: str | None = "pil", + return_dict: bool = True, + ): + """ + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `list[PIL.Image.Image]` or `torch.Tensor`): + `Image` or tensor representing an image batch to be used as the starting point. If you provide a + tensor, it needs to be compatible with the [`CLIPImageProcessor`] + [configuration](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). + Can be left as `None` only when `image_embeddings` are passed. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + decoder_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + super_res_num_inference_steps (`int`, *optional*, defaults to 7): + The number of denoising steps for super resolution. More denoising steps usually lead to a higher + quality image at the expense of slower inference. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + decoder_latents (`torch.Tensor` of shape (batch size, channels, height, width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + super_res_latents (`torch.Tensor` of shape (batch size, channels, super res height, super res width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + decoder_guidance_scale (`float`, *optional*, defaults to 4.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + image_embeddings (`torch.Tensor`, *optional*): + Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings + can be passed for tasks like image interpolations. `image` can be left as `None`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + if image is not None: + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + else: + batch_size = image_embeddings.shape[0] + + prompt = [""] * batch_size + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = decoder_guidance_scale > 1.0 + + prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance + ) + + image_embeddings = self._encode_image(image, device, num_images_per_prompt, image_embeddings) + + # decoder + text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( + image_embeddings=image_embeddings, + prompt_embeds=prompt_embeds, + text_encoder_hidden_states=text_encoder_hidden_states, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + if device.type == "mps": + # HACK: MPS: There is a panic when padding bool tensors, + # so cast to int tensor for the pad and back to bool afterwards + text_mask = text_mask.type(torch.int) + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) + decoder_text_mask = decoder_text_mask.type(torch.bool) + else: + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True) + + self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) + decoder_timesteps_tensor = self.decoder_scheduler.timesteps + + num_channels_latents = self.decoder.config.in_channels + height = self.decoder.config.sample_size + width = self.decoder.config.sample_size + + if decoder_latents is None: + decoder_latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + decoder_latents, + self.decoder_scheduler, + ) + + for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents + + noise_pred = self.decoder( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + class_labels=additive_clip_time_embeddings, + attention_mask=decoder_text_mask, + ).sample + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if i + 1 == decoder_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = decoder_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + decoder_latents = self.decoder_scheduler.step( + noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + decoder_latents = decoder_latents.clamp(-1, 1) + + image_small = decoder_latents + + # done decoder + + # super res + + self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) + super_res_timesteps_tensor = self.super_res_scheduler.timesteps + + channels = self.super_res_first.config.in_channels // 2 + height = self.super_res_first.config.sample_size + width = self.super_res_first.config.sample_size + + if super_res_latents is None: + super_res_latents = self.prepare_latents( + (batch_size, channels, height, width), + image_small.dtype, + device, + generator, + super_res_latents, + self.super_res_scheduler, + ) + + if device.type == "mps": + # MPS does not support many interpolations + image_upscaled = F.interpolate(image_small, size=[height, width]) + else: + interpolate_antialias = {} + if "antialias" in inspect.signature(F.interpolate).parameters: + interpolate_antialias["antialias"] = True + + image_upscaled = F.interpolate( + image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias + ) + + for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): + # no classifier free guidance + + if i == super_res_timesteps_tensor.shape[0] - 1: + unet = self.super_res_last + else: + unet = self.super_res_first + + latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) + + noise_pred = unet( + sample=latent_model_input, + timestep=t, + ).sample + + if i + 1 == super_res_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = super_res_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + super_res_latents = self.super_res_scheduler.step( + noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + if XLA_AVAILABLE: + xm.mark_step() + + image = super_res_latents + + # done super res + self.maybe_free_model_hooks() + + # post processing + + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/unclip/text_proj.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/unclip/text_proj.py new file mode 100644 index 0000000000000000000000000000000000000000..5e04e48ba621caa4a617e52c5b83a1573bb11975 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/unclip/text_proj.py @@ -0,0 +1,86 @@ +# Copyright 2025 Kakao Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin + + +class UnCLIPTextProjModel(ModelMixin, ConfigMixin): + """ + Utility class for CLIP embeddings. Used to combine the image and text embeddings into a format usable by the + decoder. + + For more details, see the original paper: https://huggingface.co/papers/2204.06125 section 2.1 + """ + + @register_to_config + def __init__( + self, + *, + clip_extra_context_tokens: int = 4, + clip_embeddings_dim: int = 768, + time_embed_dim: int, + cross_attention_dim, + ): + super().__init__() + + self.learned_classifier_free_guidance_embeddings = nn.Parameter(torch.zeros(clip_embeddings_dim)) + + # parameters for additional clip time embeddings + self.embedding_proj = nn.Linear(clip_embeddings_dim, time_embed_dim) + self.clip_image_embeddings_project_to_time_embeddings = nn.Linear(clip_embeddings_dim, time_embed_dim) + + # parameters for encoder hidden states + self.clip_extra_context_tokens = clip_extra_context_tokens + self.clip_extra_context_tokens_proj = nn.Linear( + clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim + ) + self.encoder_hidden_states_proj = nn.Linear(clip_embeddings_dim, cross_attention_dim) + self.text_encoder_hidden_states_norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, *, image_embeddings, prompt_embeds, text_encoder_hidden_states, do_classifier_free_guidance): + if do_classifier_free_guidance: + # Add the classifier free guidance embeddings to the image embeddings + image_embeddings_batch_size = image_embeddings.shape[0] + classifier_free_guidance_embeddings = self.learned_classifier_free_guidance_embeddings.unsqueeze(0) + classifier_free_guidance_embeddings = classifier_free_guidance_embeddings.expand( + image_embeddings_batch_size, -1 + ) + image_embeddings = torch.cat([classifier_free_guidance_embeddings, image_embeddings], dim=0) + + # The image embeddings batch size and the text embeddings batch size are equal + assert image_embeddings.shape[0] == prompt_embeds.shape[0] + + batch_size = prompt_embeds.shape[0] + + # "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and + # adding CLIP embeddings to the existing timestep embedding, ... + time_projected_prompt_embeds = self.embedding_proj(prompt_embeds) + time_projected_image_embeddings = self.clip_image_embeddings_project_to_time_embeddings(image_embeddings) + additive_clip_time_embeddings = time_projected_image_embeddings + time_projected_prompt_embeds + + # ... and by projecting CLIP embeddings into four + # extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder" + clip_extra_context_tokens = self.clip_extra_context_tokens_proj(image_embeddings) + clip_extra_context_tokens = clip_extra_context_tokens.reshape(batch_size, -1, self.clip_extra_context_tokens) + clip_extra_context_tokens = clip_extra_context_tokens.permute(0, 2, 1) + + text_encoder_hidden_states = self.encoder_hidden_states_proj(text_encoder_hidden_states) + text_encoder_hidden_states = self.text_encoder_hidden_states_norm(text_encoder_hidden_states) + text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=1) + + return text_encoder_hidden_states, additive_clip_time_embeddings diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/unidiffuser/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/unidiffuser/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ac2b09a6e570087c80bc11bf1a8102dd4970b8f --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/unidiffuser/__init__.py @@ -0,0 +1,58 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + ImageTextPipelineOutput, + UniDiffuserPipeline, + ) + + _dummy_objects.update( + {"ImageTextPipelineOutput": ImageTextPipelineOutput, "UniDiffuserPipeline": UniDiffuserPipeline} + ) +else: + _import_structure["modeling_text_decoder"] = ["UniDiffuserTextDecoder"] + _import_structure["modeling_uvit"] = ["UniDiffuserModel", "UTransformer2DModel"] + _import_structure["pipeline_unidiffuser"] = ["ImageTextPipelineOutput", "UniDiffuserPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + ImageTextPipelineOutput, + UniDiffuserPipeline, + ) + else: + from .modeling_text_decoder import UniDiffuserTextDecoder + from .modeling_uvit import UniDiffuserModel, UTransformer2DModel + from .pipeline_unidiffuser import ImageTextPipelineOutput, UniDiffuserPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/unidiffuser/modeling_text_decoder.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/unidiffuser/modeling_text_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c68b5d9ab5a84e9faa32408fae36726ae9427fee --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/unidiffuser/modeling_text_decoder.py @@ -0,0 +1,294 @@ +import numpy as np +import torch +from torch import nn +from transformers import GPT2Config, GPT2LMHeadModel +from transformers.modeling_utils import ModuleUtilsMixin + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin + + +# Modified from ClipCaptionModel in https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py +class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): + """ + Text decoder model for a image-text [UniDiffuser](https://huggingface.co/papers/2303.06555) model. This is used to + generate text from the UniDiffuser image-text embedding. + + Parameters: + prefix_length (`int`): + Max number of prefix tokens that will be supplied to the model. + prefix_inner_dim (`int`): + The hidden size of the incoming prefix embeddings. For UniDiffuser, this would be the hidden dim of the + CLIP text encoder. + prefix_hidden_dim (`int`, *optional*): + Hidden dim of the MLP if we encode the prefix. + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`]. + n_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + n_inner (`int`, *optional*, defaults to None): + Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd + activation_function (`str`, *optional*, defaults to `"gelu"`): + Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + scale_attn_weights (`bool`, *optional*, defaults to `True`): + Scale attention weights by dividing by sqrt(hidden_size).. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`): + Whether to additionally scale attention weights by `1 / layer_idx + 1`. + reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): + Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention + dot-product/softmax to float() when training with mixed precision. + """ + + _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"] + + @register_to_config + def __init__( + self, + prefix_length: int, + prefix_inner_dim: int, + prefix_hidden_dim: int | None = None, + vocab_size: int = 50257, # Start of GPT2 config args + n_positions: int = 1024, + n_embd: int = 768, + n_layer: int = 12, + n_head: int = 12, + n_inner: int | None = None, + activation_function: str = "gelu_new", + resid_pdrop: float = 0.1, + embd_pdrop: float = 0.1, + attn_pdrop: float = 0.1, + layer_norm_epsilon: float = 1e-5, + initializer_range: float = 0.02, + scale_attn_weights: bool = True, + use_cache: bool = True, + scale_attn_by_inverse_layer_idx: bool = False, + reorder_and_upcast_attn: bool = False, + ): + super().__init__() + + self.prefix_length = prefix_length + + if prefix_inner_dim != n_embd and prefix_hidden_dim is None: + raise ValueError( + f"`prefix_hidden_dim` cannot be `None` when `prefix_inner_dim`: {prefix_hidden_dim} and" + f" `n_embd`: {n_embd} are not equal." + ) + + self.prefix_inner_dim = prefix_inner_dim + self.prefix_hidden_dim = prefix_hidden_dim + + self.encode_prefix = ( + nn.Linear(self.prefix_inner_dim, self.prefix_hidden_dim) + if self.prefix_hidden_dim is not None + else nn.Identity() + ) + self.decode_prefix = ( + nn.Linear(self.prefix_hidden_dim, n_embd) if self.prefix_hidden_dim is not None else nn.Identity() + ) + + gpt_config = GPT2Config( + vocab_size=vocab_size, + n_positions=n_positions, + n_embd=n_embd, + n_layer=n_layer, + n_head=n_head, + n_inner=n_inner, + activation_function=activation_function, + resid_pdrop=resid_pdrop, + embd_pdrop=embd_pdrop, + attn_pdrop=attn_pdrop, + layer_norm_epsilon=layer_norm_epsilon, + initializer_range=initializer_range, + scale_attn_weights=scale_attn_weights, + use_cache=use_cache, + scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx, + reorder_and_upcast_attn=reorder_and_upcast_attn, + ) + self.transformer = GPT2LMHeadModel(gpt_config) + + def forward( + self, + input_ids: torch.Tensor, + prefix_embeds: torch.Tensor, + attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + ): + """ + Args: + input_ids (`torch.Tensor` of shape `(N, max_seq_len)`): + Text tokens to use for inference. + prefix_embeds (`torch.Tensor` of shape `(N, prefix_length, 768)`): + Prefix embedding to prepend to the embedded tokens. + attention_mask (`torch.Tensor` of shape `(N, prefix_length + max_seq_len, 768)`, *optional*): + Attention mask for the prefix embedding. + labels (`torch.Tensor`, *optional*): + Labels to use for language modeling. + """ + embedding_text = self.transformer.transformer.wte(input_ids) + hidden = self.encode_prefix(prefix_embeds) + prefix_embeds = self.decode_prefix(hidden) + embedding_cat = torch.cat((prefix_embeds, embedding_text), dim=1) + + if labels is not None: + dummy_token = self.get_dummy_token(input_ids.shape[0], input_ids.device) + labels = torch.cat((dummy_token, input_ids), dim=1) + out = self.transformer(inputs_embeds=embedding_cat, labels=labels, attention_mask=attention_mask) + if self.prefix_hidden_dim is not None: + return out, hidden + else: + return out + + def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor: + return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device) + + def encode(self, prefix): + return self.encode_prefix(prefix) + + @torch.no_grad() + def generate_captions(self, features, eos_token_id, device): + """ + Generate captions given text embedding features. Returns list[L]. + + Args: + features (`torch.Tensor` of shape `(B, L, D)`): + Text embedding features to generate captions from. + eos_token_id (`int`): + The token ID of the EOS token for the text decoder model. + device: + Device to perform text generation on. + + Returns: + `list[str]`: A list of strings generated from the decoder model. + """ + + features = torch.split(features, 1, dim=0) + generated_tokens = [] + generated_seq_lengths = [] + for feature in features: + feature = self.decode_prefix(feature.to(device)) # back to the clip feature + # Only support beam search for now + output_tokens, seq_lengths = self.generate_beam( + input_embeds=feature, device=device, eos_token_id=eos_token_id + ) + generated_tokens.append(output_tokens[0]) + generated_seq_lengths.append(seq_lengths[0]) + generated_tokens = torch.stack(generated_tokens) + generated_seq_lengths = torch.stack(generated_seq_lengths) + return generated_tokens, generated_seq_lengths + + @torch.no_grad() + def generate_beam( + self, + input_ids=None, + input_embeds=None, + device=None, + beam_size: int = 5, + entry_length: int = 67, + temperature: float = 1.0, + eos_token_id: int | None = None, + ): + """ + Generates text using the given tokenizer and text prompt or token embedding via beam search. This + implementation is based on the beam search implementation from the [original UniDiffuser + code](https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py#L89). + + Args: + eos_token_id (`int`, *optional*): + The token ID of the EOS token for the text decoder model. + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Tokenizer indices of input sequence tokens in the vocabulary. One of `input_ids` and `input_embeds` + must be supplied. + input_embeds (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*): + An embedded representation to directly pass to the transformer as a prefix for beam search. One of + `input_ids` and `input_embeds` must be supplied. + device: + The device to perform beam search on. + beam_size (`int`, *optional*, defaults to `5`): + The number of best states to store during beam search. + entry_length (`int`, *optional*, defaults to `67`): + The number of iterations to run beam search. + temperature (`float`, *optional*, defaults to 1.0): + The temperature to use when performing the softmax over logits from the decoding model. + + Returns: + `tuple(torch.Tensor, torch.Tensor)`: A tuple of tensors where the first element is a tensor of generated + token sequences sorted by score in descending order, and the second element is the sequence lengths + corresponding to those sequences. + """ + # Generates text until stop_token is reached using beam search with the desired beam size. + stop_token_index = eos_token_id + tokens = None + scores = None + seq_lengths = torch.ones(beam_size, device=device, dtype=torch.int) + is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool) + + if input_embeds is not None: + generated = input_embeds + else: + generated = self.transformer.transformer.wte(input_ids) + + for i in range(entry_length): + outputs = self.transformer(inputs_embeds=generated) + logits = outputs.logits + logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0) + logits = logits.softmax(-1).log() + + if scores is None: + scores, next_tokens = logits.topk(beam_size, -1) + generated = generated.expand(beam_size, *generated.shape[1:]) + next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0) + if tokens is None: + tokens = next_tokens + else: + tokens = tokens.expand(beam_size, *tokens.shape[1:]) + tokens = torch.cat((tokens, next_tokens), dim=1) + else: + logits[is_stopped] = -float(np.inf) + logits[is_stopped, 0] = 0 + scores_sum = scores[:, None] + logits + seq_lengths[~is_stopped] += 1 + scores_sum_average = scores_sum / seq_lengths[:, None] + scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1) + next_tokens_source = next_tokens // scores_sum.shape[1] + seq_lengths = seq_lengths[next_tokens_source] + next_tokens = next_tokens % scores_sum.shape[1] + next_tokens = next_tokens.unsqueeze(1) + tokens = tokens[next_tokens_source] + tokens = torch.cat((tokens, next_tokens), dim=1) + generated = generated[next_tokens_source] + scores = scores_sum_average * seq_lengths + is_stopped = is_stopped[next_tokens_source] + + next_token_embed = self.transformer.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1) + generated = torch.cat((generated, next_token_embed), dim=1) + is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze() + if is_stopped.all(): + break + + scores = scores / seq_lengths + order = scores.argsort(descending=True) + # tokens tensors are already padded to max_seq_length + output_texts = [tokens[i] for i in order] + output_texts = torch.stack(output_texts, dim=0) + seq_lengths = torch.tensor([seq_lengths[i] for i in order], dtype=seq_lengths.dtype) + return output_texts, seq_lengths diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/unidiffuser/modeling_uvit.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/unidiffuser/modeling_uvit.py new file mode 100644 index 0000000000000000000000000000000000000000..125188196c1eabc3e07f890e8bd7f97cefa41424 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/unidiffuser/modeling_uvit.py @@ -0,0 +1,1196 @@ +import math + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin +from ...models.attention import FeedForward +from ...models.attention_processor import Attention +from ...models.embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed +from ...models.modeling_outputs import Transformer2DModelOutput +from ...models.normalization import AdaLayerNorm +from ...utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + logger.warning( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect." + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + # type: (torch.Tensor, float, float, float, float) -> torch.Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the normal distribution :math:`\mathcal{N}(\text{mean}, + \text{std}^2)` with values outside :math:`[a, b]` redrawn until they are within the bounds. The method used for + generating the random values works best when :math:`a \leq \text{mean} \leq b`. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + use_pos_embed=True, + ): + super().__init__() + + num_patches = (height // patch_size) * (width // patch_size) + self.flatten = flatten + self.layer_norm = layer_norm + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + self.use_pos_embed = use_pos_embed + if self.use_pos_embed: + pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5), output_type="pt") + self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=False) + + def forward(self, latent): + latent = self.proj(latent) + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC + if self.layer_norm: + latent = self.norm(latent) + if self.use_pos_embed: + return latent + self.pos_embed + else: + return latent + + +class SkipBlock(nn.Module): + def __init__(self, dim: int): + super().__init__() + + self.skip_linear = nn.Linear(2 * dim, dim) + + # Use torch.nn.LayerNorm for now, following the original code + self.norm = nn.LayerNorm(dim) + + def forward(self, x, skip): + x = self.skip_linear(torch.cat([x, skip], dim=-1)) + x = self.norm(x) + + return x + + +# Modified to support both pre-LayerNorm and post-LayerNorm configurations +# Don't support AdaLayerNormZero for now +# Modified from diffusers.models.attention.BasicTransformerBlock +class UTransformerBlock(nn.Module): + r""" + A modification of BasicTransformerBlock which supports pre-LayerNorm and post-LayerNorm configurations. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to be used in feed-forward. + num_embeds_ada_norm (:obj: `int`, *optional*): + The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (:obj: `bool`, *optional*, defaults to `False`): + Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the query and key to float32 when performing the attention calculation. + norm_elementwise_affine (`bool`, *optional*): + Whether to use learnable per-element affine parameters during layer normalization. + norm_type (`str`, defaults to `"layer_norm"`): + The layer norm implementation to use. + pre_layer_norm (`bool`, *optional*): + Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), + as opposed to after ("post-LayerNorm"). Note that `BasicTransformerBlock` uses pre-LayerNorm, e.g. + `pre_layer_norm = True`. + final_dropout (`bool`, *optional*): + Whether to use a final Dropout layer after the feedforward network. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: int | None = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: int | None = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + pre_layer_norm: bool = True, + final_dropout: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + self.pre_layer_norm = pre_layer_norm + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # 1. Self-Attn + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.attn2 = None + + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + else: + self.norm2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + cross_attention_kwargs=None, + class_labels=None, + ): + # Pre-LayerNorm + if self.pre_layer_norm: + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + else: + norm_hidden_states = self.norm1(hidden_states) + else: + norm_hidden_states = hidden_states + + # 1. Self-Attention + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + # Post-LayerNorm + if not self.pre_layer_norm: + if self.use_ada_layer_norm: + attn_output = self.norm1(attn_output, timestep) + else: + attn_output = self.norm1(attn_output) + + hidden_states = attn_output + hidden_states + + if self.attn2 is not None: + # Pre-LayerNorm + if self.pre_layer_norm: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + else: + norm_hidden_states = hidden_states + # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly + # prepare attention mask here + + # 2. Cross-Attention + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + + # Post-LayerNorm + if not self.pre_layer_norm: + attn_output = self.norm2(attn_output, timestep) if self.use_ada_layer_norm else self.norm2(attn_output) + + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + # Pre-LayerNorm + if self.pre_layer_norm: + norm_hidden_states = self.norm3(hidden_states) + else: + norm_hidden_states = hidden_states + + ff_output = self.ff(norm_hidden_states) + + # Post-LayerNorm + if not self.pre_layer_norm: + ff_output = self.norm3(ff_output) + + hidden_states = ff_output + hidden_states + + return hidden_states + + +# Like UTransformerBlock except with LayerNorms on the residual backbone of the block +# Modified from diffusers.models.attention.BasicTransformerBlock +class UniDiffuserBlock(nn.Module): + r""" + A modification of BasicTransformerBlock which supports pre-LayerNorm and post-LayerNorm configurations and puts the + LayerNorms on the residual backbone of the block. This matches the transformer block in the [original UniDiffuser + implementation](https://github.com/thu-ml/unidiffuser/blob/main/libs/uvit_multi_post_ln_v1.py#L104). + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to be used in feed-forward. + num_embeds_ada_norm (:obj: `int`, *optional*): + The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (:obj: `bool`, *optional*, defaults to `False`): + Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the query and key to float() when performing the attention calculation. + norm_elementwise_affine (`bool`, *optional*): + Whether to use learnable per-element affine parameters during layer normalization. + norm_type (`str`, defaults to `"layer_norm"`): + The layer norm implementation to use. + pre_layer_norm (`bool`, *optional*): + Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), + as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm + (`pre_layer_norm = False`). + final_dropout (`bool`, *optional*): + Whether to use a final Dropout layer after the feedforward network. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: int | None = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: int | None = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + pre_layer_norm: bool = False, + final_dropout: bool = True, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + self.pre_layer_norm = pre_layer_norm + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # 1. Self-Attn + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.attn2 = None + + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + else: + self.norm2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + cross_attention_kwargs=None, + class_labels=None, + ): + # Following the diffusers transformer block implementation, put the LayerNorm on the + # residual backbone + # Pre-LayerNorm + if self.pre_layer_norm: + if self.use_ada_layer_norm: + hidden_states = self.norm1(hidden_states, timestep) + else: + hidden_states = self.norm1(hidden_states) + + # 1. Self-Attention + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + attn_output = self.attn1( + hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + # Following the diffusers transformer block implementation, put the LayerNorm on the + # residual backbone + # Post-LayerNorm + if not self.pre_layer_norm: + if self.use_ada_layer_norm: + hidden_states = self.norm1(hidden_states, timestep) + else: + hidden_states = self.norm1(hidden_states) + + if self.attn2 is not None: + # Pre-LayerNorm + if self.pre_layer_norm: + hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly + # prepare attention mask here + + # 2. Cross-Attention + attn_output = self.attn2( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + # Post-LayerNorm + if not self.pre_layer_norm: + hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + # 3. Feed-forward + # Pre-LayerNorm + if self.pre_layer_norm: + hidden_states = self.norm3(hidden_states) + + ff_output = self.ff(hidden_states) + + hidden_states = ff_output + hidden_states + + # Post-LayerNorm + if not self.pre_layer_norm: + hidden_states = self.norm3(hidden_states) + + return hidden_states + + +# Modified from diffusers.models.transformer_2d.Transformer2DModel +# Modify the transformer block structure to be U-Net like following U-ViT +# Only supports patch-style input and torch.nn.LayerNorm currently +# https://github.com/baofff/U-ViT +class UTransformer2DModel(ModelMixin, ConfigMixin): + """ + Transformer model based on the [U-ViT](https://github.com/baofff/U-ViT) architecture for image-like data. Compared + to [`Transformer2DModel`], this model has skip connections between transformer blocks in a "U"-shaped fashion, + similar to a U-Net. Supports only continuous (actual embeddings) inputs, which are embedded via a [`PatchEmbed`] + layer and then reshaped to (b, t, d). + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input. + out_channels (`int`, *optional*): + The number of output channels; if `None`, defaults to `in_channels`. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + norm_num_groups (`int`, *optional*, defaults to `32`): + The number of groups to use when performing Group Normalization. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + patch_size (`int`, *optional*, defaults to 2): + The patch size to use in the patch embedding. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + use_linear_projection (int, *optional*): TODO: Not used + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used in each + transformer block. + upcast_attention (`bool`, *optional*): + Whether to upcast the query and key to float() when performing the attention calculation. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`. + block_type (`str`, *optional*, defaults to `"unidiffuser"`): + The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual + backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard + behavior in `diffusers`.) + pre_layer_norm (`bool`, *optional*): + Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), + as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm + (`pre_layer_norm = False`). + norm_elementwise_affine (`bool`, *optional*): + Whether to use learnable per-element affine parameters during layer normalization. + use_patch_pos_embed (`bool`, *optional*): + Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`). + final_dropout (`bool`, *optional*): + Whether to use a final Dropout layer after the feedforward network. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: int | None = None, + out_channels: int | None = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: int | None = None, + attention_bias: bool = False, + sample_size: int | None = None, + num_vector_embeds: int | None = None, + patch_size: int | None = 2, + activation_fn: str = "geglu", + num_embeds_ada_norm: int | None = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + block_type: str = "unidiffuser", + pre_layer_norm: bool = False, + norm_elementwise_affine: bool = True, + use_patch_pos_embed=False, + ff_final_dropout: bool = False, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Input + # Only support patch input of shape (batch_size, num_channels, height, width) for now + assert in_channels is not None and patch_size is not None, "Patch input requires in_channels and patch_size." + + assert sample_size is not None, "UTransformer2DModel over patched input must provide sample_size" + + # 2. Define input layers + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + use_pos_embed=use_patch_pos_embed, + ) + + # 3. Define transformers blocks + # Modify this to have in_blocks ("downsample" blocks, even though we don't actually downsample), a mid_block, + # and out_blocks ("upsample" blocks). Like a U-Net, there are skip connections from in_blocks to out_blocks in + # a "U"-shaped fashion (e.g. first in_block to last out_block, etc.). + # Quick hack to make the transformer block type configurable + if block_type == "unidiffuser": + block_cls = UniDiffuserBlock + else: + block_cls = UTransformerBlock + self.transformer_in_blocks = nn.ModuleList( + [ + block_cls( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + pre_layer_norm=pre_layer_norm, + norm_elementwise_affine=norm_elementwise_affine, + final_dropout=ff_final_dropout, + ) + for d in range(num_layers // 2) + ] + ) + + self.transformer_mid_block = block_cls( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + pre_layer_norm=pre_layer_norm, + norm_elementwise_affine=norm_elementwise_affine, + final_dropout=ff_final_dropout, + ) + + # For each skip connection, we use a SkipBlock (concatenation + Linear + LayerNorm) to process the inputs + # before each transformer out_block. + self.transformer_out_blocks = nn.ModuleList( + [ + nn.ModuleDict( + { + "skip": SkipBlock( + inner_dim, + ), + "block": block_cls( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + pre_layer_norm=pre_layer_norm, + norm_elementwise_affine=norm_elementwise_affine, + final_dropout=ff_final_dropout, + ), + } + ) + for d in range(num_layers // 2) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + + # Following the UniDiffuser U-ViT implementation, we process the transformer output with + # a LayerNorm layer with per-element affine params + self.norm_out = nn.LayerNorm(inner_dim) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + return_dict: bool = True, + hidden_states_is_embedding: bool = False, + unpatchify: bool = True, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continuous, `torch.Tensor` of shape `(batch size, channel, height, width)`): Input hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels + conditioning. + cross_attention_kwargs (*optional*): + Keyword arguments to supply to the cross attention layers, if used. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + hidden_states_is_embedding (`bool`, *optional*, defaults to `False`): + Whether or not hidden_states is an embedding directly usable by the transformer. In this case we will + ignore input handling (e.g. continuous, vectorized, etc.) and directly feed hidden_states into the + transformer blocks. + unpatchify (`bool`, *optional*, defaults to `True`): + Whether to unpatchify the transformer output. + + Returns: + [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: + [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # 0. Check inputs + + if not unpatchify and return_dict: + raise ValueError( + f"Cannot both define `unpatchify`: {unpatchify} and `return_dict`: {return_dict} since when" + f" `unpatchify` is {unpatchify} the returned output is of shape (batch_size, seq_len, hidden_dim)" + " rather than (batch_size, num_channels, height, width)." + ) + + # 1. Input + if not hidden_states_is_embedding: + hidden_states = self.pos_embed(hidden_states) + + # 2. Blocks + + # In ("downsample") blocks + skips = [] + for in_block in self.transformer_in_blocks: + hidden_states = in_block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + skips.append(hidden_states) + + # Mid block + hidden_states = self.transformer_mid_block(hidden_states) + + # Out ("upsample") blocks + for out_block in self.transformer_out_blocks: + hidden_states = out_block["skip"](hidden_states, skips.pop()) + hidden_states = out_block["block"]( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + # Don't support AdaLayerNorm for now, so no conditioning/scale/shift logic + hidden_states = self.norm_out(hidden_states) + # hidden_states = self.proj_out(hidden_states) + + if unpatchify: + # unpatchify + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + else: + output = hidden_states + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + +class UniDiffuserModel(ModelMixin, ConfigMixin): + """ + Transformer model for a image-text [UniDiffuser](https://huggingface.co/papers/2303.06555) model. This is a + modification of [`UTransformer2DModel`] with input and output heads for the VAE-embedded latent image, the + CLIP-embedded image, and the CLIP-embedded prompt (see paper for more details). + + Parameters: + text_dim (`int`): The hidden dimension of the CLIP text model used to embed images. + clip_img_dim (`int`): The hidden dimension of the CLIP vision model used to embed prompts. + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input. + out_channels (`int`, *optional*): + The number of output channels; if `None`, defaults to `in_channels`. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + norm_num_groups (`int`, *optional*, defaults to `32`): + The number of groups to use when performing Group Normalization. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + patch_size (`int`, *optional*, defaults to 2): + The patch size to use in the patch embedding. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + use_linear_projection (int, *optional*): TODO: Not used + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used in each + transformer block. + upcast_attention (`bool`, *optional*): + Whether to upcast the query and key to float32 when performing the attention calculation. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`. + block_type (`str`, *optional*, defaults to `"unidiffuser"`): + The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual + backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard + behavior in `diffusers`.) + pre_layer_norm (`bool`, *optional*): + Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), + as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm + (`pre_layer_norm = False`). + norm_elementwise_affine (`bool`, *optional*): + Whether to use learnable per-element affine parameters during layer normalization. + use_patch_pos_embed (`bool`, *optional*): + Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`). + ff_final_dropout (`bool`, *optional*): + Whether to use a final Dropout layer after the feedforward network. + use_data_type_embedding (`bool`, *optional*): + Whether to use a data type embedding. This is only relevant for UniDiffuser-v1 style models; UniDiffuser-v1 + is continue-trained from UniDiffuser-v0 on non-publically-available data and accepts a `data_type` + argument, which can either be `1` to use the weights trained on non-publically-available data or `0` + otherwise. This argument is subsequently embedded by the data type embedding, if used. + """ + + @register_to_config + def __init__( + self, + text_dim: int = 768, + clip_img_dim: int = 512, + num_text_tokens: int = 77, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: int | None = None, + out_channels: int | None = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: int | None = None, + attention_bias: bool = False, + sample_size: int | None = None, + num_vector_embeds: int | None = None, + patch_size: int | None = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: int | None = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + block_type: str = "unidiffuser", + pre_layer_norm: bool = False, + use_timestep_embedding=False, + norm_elementwise_affine: bool = True, + use_patch_pos_embed=False, + ff_final_dropout: bool = True, + use_data_type_embedding: bool = False, + ): + super().__init__() + + # 0. Handle dimensions + self.inner_dim = num_attention_heads * attention_head_dim + + assert sample_size is not None, "UniDiffuserModel over patched input must provide sample_size" + self.sample_size = sample_size + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + + self.patch_size = patch_size + # Assume image is square... + self.num_patches = (self.sample_size // patch_size) * (self.sample_size // patch_size) + + # 1. Define input layers + # 1.1 Input layers for text and image input + # For now, only support patch input for VAE latent image input + self.vae_img_in = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=self.inner_dim, + use_pos_embed=use_patch_pos_embed, + ) + self.clip_img_in = nn.Linear(clip_img_dim, self.inner_dim) + self.text_in = nn.Linear(text_dim, self.inner_dim) + + # 1.2. Timestep embeddings for t_img, t_text + self.timestep_img_proj = Timesteps( + self.inner_dim, + flip_sin_to_cos=True, + downscale_freq_shift=0, + ) + self.timestep_img_embed = ( + TimestepEmbedding( + self.inner_dim, + 4 * self.inner_dim, + out_dim=self.inner_dim, + ) + if use_timestep_embedding + else nn.Identity() + ) + + self.timestep_text_proj = Timesteps( + self.inner_dim, + flip_sin_to_cos=True, + downscale_freq_shift=0, + ) + self.timestep_text_embed = ( + TimestepEmbedding( + self.inner_dim, + 4 * self.inner_dim, + out_dim=self.inner_dim, + ) + if use_timestep_embedding + else nn.Identity() + ) + + # 1.3. Positional embedding + self.num_text_tokens = num_text_tokens + self.num_tokens = 1 + 1 + num_text_tokens + 1 + self.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, self.inner_dim)) + self.pos_embed_drop = nn.Dropout(p=dropout) + trunc_normal_(self.pos_embed, std=0.02) + + # 1.4. Handle data type token embeddings for UniDiffuser-V1, if necessary + self.use_data_type_embedding = use_data_type_embedding + if self.use_data_type_embedding: + self.data_type_token_embedding = nn.Embedding(2, self.inner_dim) + self.data_type_pos_embed_token = nn.Parameter(torch.zeros(1, 1, self.inner_dim)) + + # 2. Define transformer blocks + self.transformer = UTransformer2DModel( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + out_channels=out_channels, + num_layers=num_layers, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + sample_size=sample_size, + num_vector_embeds=num_vector_embeds, + patch_size=patch_size, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + block_type=block_type, + pre_layer_norm=pre_layer_norm, + norm_elementwise_affine=norm_elementwise_affine, + use_patch_pos_embed=use_patch_pos_embed, + ff_final_dropout=ff_final_dropout, + ) + + # 3. Define output layers + patch_dim = (patch_size**2) * out_channels + self.vae_img_out = nn.Linear(self.inner_dim, patch_dim) + self.clip_img_out = nn.Linear(self.inner_dim, clip_img_dim) + self.text_out = nn.Linear(self.inner_dim, text_dim) + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed"} + + def forward( + self, + latent_image_embeds: torch.Tensor, + image_embeds: torch.Tensor, + prompt_embeds: torch.Tensor, + timestep_img: torch.Tensor | float | int, + timestep_text: torch.Tensor | float | int, + data_type: torch.Tensor | float | int | None = 1, + encoder_hidden_states=None, + cross_attention_kwargs=None, + ): + """ + Args: + latent_image_embeds (`torch.Tensor` of shape `(batch size, latent channels, height, width)`): + Latent image representation from the VAE encoder. + image_embeds (`torch.Tensor` of shape `(batch size, 1, clip_img_dim)`): + CLIP-embedded image representation (unsqueezed in the first dimension). + prompt_embeds (`torch.Tensor` of shape `(batch size, seq_len, text_dim)`): + CLIP-embedded text representation. + timestep_img (`torch.long` or `float` or `int`): + Current denoising step for the image. + timestep_text (`torch.long` or `float` or `int`): + Current denoising step for the text. + data_type: (`torch.int` or `float` or `int`, *optional*, defaults to `1`): + Only used in UniDiffuser-v1-style models. Can be either `1`, to use weights trained on nonpublic data, + or `0` otherwise. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + cross_attention_kwargs (*optional*): + Keyword arguments to supply to the cross attention layers, if used. + + + Returns: + `tuple`: Returns relevant parts of the model's noise prediction: the first element of the tuple is tbe VAE + image embedding, the second element is the CLIP image embedding, and the third element is the CLIP text + embedding. + """ + batch_size = latent_image_embeds.shape[0] + + # 1. Input + # 1.1. Map inputs to shape (B, N, inner_dim) + vae_hidden_states = self.vae_img_in(latent_image_embeds) + clip_hidden_states = self.clip_img_in(image_embeds) + text_hidden_states = self.text_in(prompt_embeds) + + num_text_tokens, num_img_tokens = text_hidden_states.size(1), vae_hidden_states.size(1) + + # 1.2. Encode image timesteps to single token (B, 1, inner_dim) + if not torch.is_tensor(timestep_img): + timestep_img = torch.tensor([timestep_img], dtype=torch.long, device=vae_hidden_states.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep_img = timestep_img * torch.ones(batch_size, dtype=timestep_img.dtype, device=timestep_img.device) + + timestep_img_token = self.timestep_img_proj(timestep_img) + # t_img_token does not contain any weights and will always return f32 tensors + # but time_embedding might be fp16, so we need to cast here. + timestep_img_token = timestep_img_token.to(dtype=self.dtype) + timestep_img_token = self.timestep_img_embed(timestep_img_token) + timestep_img_token = timestep_img_token.unsqueeze(dim=1) + + # 1.3. Encode text timesteps to single token (B, 1, inner_dim) + if not torch.is_tensor(timestep_text): + timestep_text = torch.tensor([timestep_text], dtype=torch.long, device=vae_hidden_states.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep_text = timestep_text * torch.ones(batch_size, dtype=timestep_text.dtype, device=timestep_text.device) + + timestep_text_token = self.timestep_text_proj(timestep_text) + # t_text_token does not contain any weights and will always return f32 tensors + # but time_embedding might be fp16, so we need to cast here. + timestep_text_token = timestep_text_token.to(dtype=self.dtype) + timestep_text_token = self.timestep_text_embed(timestep_text_token) + timestep_text_token = timestep_text_token.unsqueeze(dim=1) + + # 1.4. Concatenate all of the embeddings together. + if self.use_data_type_embedding: + assert data_type is not None, "data_type must be supplied if the model uses a data type embedding" + if not torch.is_tensor(data_type): + data_type = torch.tensor([data_type], dtype=torch.int, device=vae_hidden_states.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + data_type = data_type * torch.ones(batch_size, dtype=data_type.dtype, device=data_type.device) + + data_type_token = self.data_type_token_embedding(data_type).unsqueeze(dim=1) + hidden_states = torch.cat( + [ + timestep_img_token, + timestep_text_token, + data_type_token, + text_hidden_states, + clip_hidden_states, + vae_hidden_states, + ], + dim=1, + ) + else: + hidden_states = torch.cat( + [timestep_img_token, timestep_text_token, text_hidden_states, clip_hidden_states, vae_hidden_states], + dim=1, + ) + + # 1.5. Prepare the positional embeddings and add to hidden states + # Note: I think img_vae should always have the proper shape, so there's no need to interpolate + # the position embeddings. + if self.use_data_type_embedding: + pos_embed = torch.cat( + [self.pos_embed[:, : 1 + 1, :], self.data_type_pos_embed_token, self.pos_embed[:, 1 + 1 :, :]], dim=1 + ) + else: + pos_embed = self.pos_embed + hidden_states = hidden_states + pos_embed + hidden_states = self.pos_embed_drop(hidden_states) + + # 2. Blocks + hidden_states = self.transformer( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=None, + class_labels=None, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + hidden_states_is_embedding=True, + unpatchify=False, + )[0] + + # 3. Output + # Split out the predicted noise representation. + if self.use_data_type_embedding: + ( + t_img_token_out, + t_text_token_out, + data_type_token_out, + text_out, + img_clip_out, + img_vae_out, + ) = hidden_states.split((1, 1, 1, num_text_tokens, 1, num_img_tokens), dim=1) + else: + t_img_token_out, t_text_token_out, text_out, img_clip_out, img_vae_out = hidden_states.split( + (1, 1, num_text_tokens, 1, num_img_tokens), dim=1 + ) + + img_vae_out = self.vae_img_out(img_vae_out) + + # unpatchify + height = width = int(img_vae_out.shape[1] ** 0.5) + img_vae_out = img_vae_out.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + img_vae_out = torch.einsum("nhwpqc->nchpwq", img_vae_out) + img_vae_out = img_vae_out.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + img_clip_out = self.clip_img_out(img_clip_out) + + text_out = self.text_out(text_out) + + return img_vae_out, img_clip_out, text_out diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py new file mode 100644 index 0000000000000000000000000000000000000000..81d2ce95dc5348fcad141c5d6f3a7cf79815485e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py @@ -0,0 +1,1462 @@ +import inspect +from dataclasses import dataclass +from typing import Callable + +import numpy as np +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, + GPT2Tokenizer, +) + +from ...image_processor import VaeImageProcessor +from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.outputs import BaseOutput +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline +from .modeling_text_decoder import UniDiffuserTextDecoder +from .modeling_uvit import UniDiffuserModel + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# New BaseOutput child class for joint image-text output +@dataclass +class ImageTextPipelineOutput(BaseOutput): + """ + Output class for joint image-text pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + list of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + text (`list[str]` or `list[list[str]]`) + list of generated text strings of length `batch_size` or a list of list of strings whose outer list has + length `batch_size`. + """ + + images: list[PIL.Image.Image] | np.ndarray | None + text: list[str] | list[list[str]] | None + + +class UniDiffuserPipeline(DeprecatedPipelineMixin, DiffusionPipeline): + r""" + Pipeline for a bimodal image-text model which supports unconditional text and image generation, text-conditioned + image generation, image-conditioned text generation, and joint image-text generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. This + is part of the UniDiffuser image representation along with the CLIP vision encoding. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + image_encoder ([`CLIPVisionModel`]): + A [`~transformers.CLIPVisionModel`] to encode images as part of its image representation along with the VAE + latent representation. + image_processor ([`CLIPImageProcessor`]): + [`~transformers.CLIPImageProcessor`] to preprocess an image before CLIP encoding it with `image_encoder`. + clip_tokenizer ([`CLIPTokenizer`]): + A [`~transformers.CLIPTokenizer`] to tokenize the prompt before encoding it with `text_encoder`. + text_decoder ([`UniDiffuserTextDecoder`]): + Frozen text decoder. This is a GPT-style model which is used to generate text from the UniDiffuser + embedding. + text_tokenizer ([`GPT2Tokenizer`]): + A [`~transformers.GPT2Tokenizer`] to decode text for text generation; used along with the `text_decoder`. + unet ([`UniDiffuserModel`]): + A [U-ViT](https://github.com/baofff/U-ViT) model with UNNet-style skip connections between transformer + layers to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image and/or text latents. The + original UniDiffuser paper uses the [`DPMSolverMultistepScheduler`] scheduler. + """ + + _last_supported_version = "0.33.1" + # TODO: support for moving submodules for components with enable_model_cpu_offload + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae->text_decoder" + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + image_encoder: CLIPVisionModelWithProjection, + clip_image_processor: CLIPImageProcessor, + clip_tokenizer: CLIPTokenizer, + text_decoder: UniDiffuserTextDecoder, + text_tokenizer: GPT2Tokenizer, + unet: UniDiffuserModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + if text_encoder.config.hidden_size != text_decoder.prefix_inner_dim: + raise ValueError( + f"The text encoder hidden size and text decoder prefix inner dim must be the same, but" + f" `text_encoder.config.hidden_size`: {text_encoder.config.hidden_size} and `text_decoder.prefix_inner_dim`: {text_decoder.prefix_inner_dim}" + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + image_encoder=image_encoder, + clip_image_processor=clip_image_processor, + clip_tokenizer=clip_tokenizer, + text_decoder=text_decoder, + text_tokenizer=text_tokenizer, + unet=unet, + scheduler=scheduler, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.num_channels_latents = vae.config.latent_channels + self.text_encoder_seq_len = text_encoder.config.max_position_embeddings + self.text_encoder_hidden_size = text_encoder.config.hidden_size + self.image_encoder_projection_dim = image_encoder.config.projection_dim + self.unet_resolution = unet.config.sample_size + + self.text_intermediate_dim = self.text_encoder_hidden_size + if self.text_decoder.prefix_hidden_dim is not None: + self.text_intermediate_dim = self.text_decoder.prefix_hidden_dim + + self.mode = None + + # TODO: handle safety checking? + self.safety_checker = None + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def _infer_mode(self, prompt, prompt_embeds, image, latents, prompt_latents, vae_latents, clip_latents): + r""" + Infer the generation task ('mode') from the inputs to `__call__`. If the mode has been manually set, the set + mode will be used. + """ + prompt_available = (prompt is not None) or (prompt_embeds is not None) + image_available = image is not None + input_available = prompt_available or image_available + + prompt_latents_available = prompt_latents is not None + vae_latents_available = vae_latents is not None + clip_latents_available = clip_latents is not None + full_latents_available = latents is not None + image_latents_available = vae_latents_available and clip_latents_available + all_indv_latents_available = prompt_latents_available and image_latents_available + + if self.mode is not None: + # Preferentially use the mode set by the user + mode = self.mode + elif prompt_available: + mode = "text2img" + elif image_available: + mode = "img2text" + else: + # Neither prompt nor image supplied, infer based on availability of latents + if full_latents_available or all_indv_latents_available: + mode = "joint" + elif prompt_latents_available: + mode = "text" + elif image_latents_available: + mode = "img" + else: + # No inputs or latents available + mode = "joint" + + # Give warnings for ambiguous cases + if self.mode is None and prompt_available and image_available: + logger.warning( + f"You have supplied both a text prompt and image to the pipeline and mode has not been set manually," + f" defaulting to mode '{mode}'." + ) + + if self.mode is None and not input_available: + if vae_latents_available != clip_latents_available: + # Exactly one of vae_latents and clip_latents is supplied + logger.warning( + f"You have supplied exactly one of `vae_latents` and `clip_latents`, whereas either both or none" + f" are expected to be supplied. Defaulting to mode '{mode}'." + ) + elif not prompt_latents_available and not vae_latents_available and not clip_latents_available: + # No inputs or latents supplied + logger.warning( + f"No inputs or latents have been supplied, and mode has not been manually set," + f" defaulting to mode '{mode}'." + ) + + return mode + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + # Functions to manually set the mode + def set_text_mode(self): + r"""Manually set the generation mode to unconditional ("marginal") text generation.""" + self.mode = "text" + + def set_image_mode(self): + r"""Manually set the generation mode to unconditional ("marginal") image generation.""" + self.mode = "img" + + def set_text_to_image_mode(self): + r"""Manually set the generation mode to text-conditioned image generation.""" + self.mode = "text2img" + + def set_image_to_text_mode(self): + r"""Manually set the generation mode to image-conditioned text generation.""" + self.mode = "img2text" + + def set_joint_mode(self): + r"""Manually set the generation mode to unconditional joint image-text generation.""" + self.mode = "joint" + + def reset_mode(self): + r"""Removes a manually set mode; after calling this, the pipeline will infer the mode from inputs.""" + self.mode = None + + def _infer_batch_size( + self, + mode, + prompt, + prompt_embeds, + image, + num_images_per_prompt, + num_prompts_per_image, + latents, + prompt_latents, + vae_latents, + clip_latents, + ): + r"""Infers the batch size and multiplier depending on mode and supplied arguments to `__call__`.""" + if num_images_per_prompt is None: + num_images_per_prompt = 1 + if num_prompts_per_image is None: + num_prompts_per_image = 1 + + assert num_images_per_prompt > 0, "num_images_per_prompt must be a positive integer" + assert num_prompts_per_image > 0, "num_prompts_per_image must be a positive integer" + + if mode in ["text2img"]: + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + # Either prompt or prompt_embeds must be present for text2img. + batch_size = prompt_embeds.shape[0] + multiplier = num_images_per_prompt + elif mode in ["img2text"]: + if isinstance(image, PIL.Image.Image): + batch_size = 1 + else: + # Image must be available and type either PIL.Image.Image or torch.Tensor. + # Not currently supporting something like image_embeds. + batch_size = image.shape[0] + multiplier = num_prompts_per_image + elif mode in ["img"]: + if vae_latents is not None: + batch_size = vae_latents.shape[0] + elif clip_latents is not None: + batch_size = clip_latents.shape[0] + else: + batch_size = 1 + multiplier = num_images_per_prompt + elif mode in ["text"]: + if prompt_latents is not None: + batch_size = prompt_latents.shape[0] + else: + batch_size = 1 + multiplier = num_prompts_per_image + elif mode in ["joint"]: + if latents is not None: + batch_size = latents.shape[0] + elif prompt_latents is not None: + batch_size = prompt_latents.shape[0] + elif vae_latents is not None: + batch_size = vae_latents.shape[0] + elif clip_latents is not None: + batch_size = clip_latents.shape[0] + else: + batch_size = 1 + + if num_images_per_prompt == num_prompts_per_image: + multiplier = num_images_per_prompt + else: + multiplier = min(num_images_per_prompt, num_prompts_per_image) + logger.warning( + f"You are using mode `{mode}` and `num_images_per_prompt`: {num_images_per_prompt} and" + f" num_prompts_per_image: {num_prompts_per_image} are not equal. Using batch size equal to" + f" `min(num_images_per_prompt, num_prompts_per_image) = {batch_size}." + ) + return batch_size, multiplier + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with self.tokenizer->self.clip_tokenizer + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + lora_scale: float | None = None, + clip_skip: int | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.clip_tokenizer) + + text_inputs = self.clip_tokenizer( + prompt, + padding="max_length", + max_length=self.clip_tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.clip_tokenizer.batch_decode( + untruncated_ids[:, self.clip_tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.clip_tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.clip_tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.clip_tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.prepare_image_latents + # Add num_prompts_per_image argument, sample from autoencoder moment distribution + def encode_image_vae_latents( + self, + image, + batch_size, + num_prompts_per_image, + dtype, + device, + do_classifier_free_guidance, + generator=None, + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_prompts_per_image + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + * self.vae.config.scaling_factor + for i in range(batch_size) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + # Scale image_latents by the VAE's scaling factor + image_latents = image_latents * self.vae.config.scaling_factor + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if do_classifier_free_guidance: + uncond_image_latents = torch.zeros_like(image_latents) + image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) + + return image_latents + + def encode_image_clip_latents( + self, + image, + batch_size, + num_prompts_per_image, + dtype, + device, + generator=None, + ): + # Map image to CLIP embedding. + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + preprocessed_image = self.clip_image_processor.preprocess( + image, + return_tensors="pt", + ) + preprocessed_image = preprocessed_image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_prompts_per_image + if isinstance(generator, list): + image_latents = [ + self.image_encoder(**preprocessed_image[i : i + 1]).image_embeds for i in range(batch_size) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.image_encoder(**preprocessed_image).image_embeds + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + return image_latents + + def prepare_text_latents( + self, batch_size, num_images_per_prompt, seq_len, hidden_size, dtype, device, generator, latents=None + ): + # Prepare latents for the CLIP embedded prompt. + shape = (batch_size * num_images_per_prompt, seq_len, hidden_size) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + # latents is assumed to have shace (B, L, D) + latents = latents.repeat(num_images_per_prompt, 1, 1) + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + # Rename prepare_latents -> prepare_image_vae_latents and add num_prompts_per_image argument. + def prepare_image_vae_latents( + self, + batch_size, + num_prompts_per_image, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size * num_prompts_per_image, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + # latents is assumed to have shape (B, C, H, W) + latents = latents.repeat(num_prompts_per_image, 1, 1, 1) + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_image_clip_latents( + self, batch_size, num_prompts_per_image, clip_img_dim, dtype, device, generator, latents=None + ): + # Prepare latents for the CLIP embedded image. + shape = (batch_size * num_prompts_per_image, 1, clip_img_dim) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + # latents is assumed to have shape (B, L, D) + latents = latents.repeat(num_prompts_per_image, 1, 1) + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def decode_text_latents(self, text_latents, device): + output_token_list, seq_lengths = self.text_decoder.generate_captions( + text_latents, self.text_tokenizer.eos_token_id, device=device + ) + output_list = output_token_list.cpu().numpy() + generated_text = [ + self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True) + for output, length in zip(output_list, seq_lengths) + ] + return generated_text + + def _split(self, x, height, width): + r""" + Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim) into two tensors of shape (B, C, H, W) + and (B, 1, clip_img_dim) + """ + batch_size = x.shape[0] + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + img_vae_dim = self.num_channels_latents * latent_height * latent_width + + img_vae, img_clip = x.split([img_vae_dim, self.image_encoder_projection_dim], dim=1) + + img_vae = torch.reshape(img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width)) + img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_projection_dim)) + return img_vae, img_clip + + def _combine(self, img_vae, img_clip): + r""" + Combines a latent image img_vae of shape (B, C, H, W) and a CLIP-embedded image img_clip of shape (B, 1, + clip_img_dim) into a single tensor of shape (B, C * H * W + clip_img_dim). + """ + img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1)) + img_clip = torch.reshape(img_clip, (img_clip.shape[0], -1)) + return torch.concat([img_vae, img_clip], dim=-1) + + def _split_joint(self, x, height, width): + r""" + Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim + text_seq_len * text_dim] into (img_vae, + img_clip, text) where img_vae is of shape (B, C, H, W), img_clip is of shape (B, 1, clip_img_dim), and text is + of shape (B, text_seq_len, text_dim). + """ + batch_size = x.shape[0] + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + img_vae_dim = self.num_channels_latents * latent_height * latent_width + text_dim = self.text_encoder_seq_len * self.text_intermediate_dim + + img_vae, img_clip, text = x.split([img_vae_dim, self.image_encoder_projection_dim, text_dim], dim=1) + + img_vae = torch.reshape(img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width)) + img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_projection_dim)) + text = torch.reshape(text, (batch_size, self.text_encoder_seq_len, self.text_intermediate_dim)) + return img_vae, img_clip, text + + def _combine_joint(self, img_vae, img_clip, text): + r""" + Combines a latent image img_vae of shape (B, C, H, W), a CLIP-embedded image img_clip of shape (B, L_img, + clip_img_dim), and a text embedding text of shape (B, L_text, text_dim) into a single embedding x of shape (B, + C * H * W + L_img * clip_img_dim + L_text * text_dim). + """ + img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1)) + img_clip = torch.reshape(img_clip, (img_clip.shape[0], -1)) + text = torch.reshape(text, (text.shape[0], -1)) + return torch.concat([img_vae, img_clip, text], dim=-1) + + def _get_noise_pred( + self, + mode, + latents, + t, + prompt_embeds, + img_vae, + img_clip, + max_timestep, + data_type, + guidance_scale, + generator, + device, + height, + width, + ): + r""" + Gets the noise prediction using the `unet` and performs classifier-free guidance, if necessary. + """ + if mode == "joint": + # Joint text-image generation + img_vae_latents, img_clip_latents, text_latents = self._split_joint(latents, height, width) + + img_vae_out, img_clip_out, text_out = self.unet( + img_vae_latents, img_clip_latents, text_latents, timestep_img=t, timestep_text=t, data_type=data_type + ) + + x_out = self._combine_joint(img_vae_out, img_clip_out, text_out) + + if guidance_scale <= 1.0: + return x_out + + # Classifier-free guidance + img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype) + img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype) + text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + + _, _, text_out_uncond = self.unet( + img_vae_T, img_clip_T, text_latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type + ) + + img_vae_out_uncond, img_clip_out_uncond, _ = self.unet( + img_vae_latents, + img_clip_latents, + text_T, + timestep_img=t, + timestep_text=max_timestep, + data_type=data_type, + ) + + x_out_uncond = self._combine_joint(img_vae_out_uncond, img_clip_out_uncond, text_out_uncond) + + return guidance_scale * x_out + (1.0 - guidance_scale) * x_out_uncond + elif mode == "text2img": + # Text-conditioned image generation + img_vae_latents, img_clip_latents = self._split(latents, height, width) + + img_vae_out, img_clip_out, text_out = self.unet( + img_vae_latents, img_clip_latents, prompt_embeds, timestep_img=t, timestep_text=0, data_type=data_type + ) + + img_out = self._combine(img_vae_out, img_clip_out) + + if guidance_scale <= 1.0: + return img_out + + # Classifier-free guidance + text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + + img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet( + img_vae_latents, + img_clip_latents, + text_T, + timestep_img=t, + timestep_text=max_timestep, + data_type=data_type, + ) + + img_out_uncond = self._combine(img_vae_out_uncond, img_clip_out_uncond) + + return guidance_scale * img_out + (1.0 - guidance_scale) * img_out_uncond + elif mode == "img2text": + # Image-conditioned text generation + img_vae_out, img_clip_out, text_out = self.unet( + img_vae, img_clip, latents, timestep_img=0, timestep_text=t, data_type=data_type + ) + + if guidance_scale <= 1.0: + return text_out + + # Classifier-free guidance + img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype) + img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype) + + img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet( + img_vae_T, img_clip_T, latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type + ) + + return guidance_scale * text_out + (1.0 - guidance_scale) * text_out_uncond + elif mode == "text": + # Unconditional ("marginal") text generation (no CFG) + img_vae_out, img_clip_out, text_out = self.unet( + img_vae, img_clip, latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type + ) + + return text_out + elif mode == "img": + # Unconditional ("marginal") image generation (no CFG) + img_vae_latents, img_clip_latents = self._split(latents, height, width) + + img_vae_out, img_clip_out, text_out = self.unet( + img_vae_latents, + img_clip_latents, + prompt_embeds, + timestep_img=t, + timestep_text=max_timestep, + data_type=data_type, + ) + + img_out = self._combine(img_vae_out, img_clip_out) + return img_out + + def check_latents_shape(self, latents_name, latents, expected_shape): + latents_shape = latents.shape + expected_num_dims = len(expected_shape) + 1 # expected dimensions plus the batch dimension + expected_shape_str = ", ".join(str(dim) for dim in expected_shape) + if len(latents_shape) != expected_num_dims: + raise ValueError( + f"`{latents_name}` should have shape (batch_size, {expected_shape_str}), but the current shape" + f" {latents_shape} has {len(latents_shape)} dimensions." + ) + for i in range(1, expected_num_dims): + if latents_shape[i] != expected_shape[i - 1]: + raise ValueError( + f"`{latents_name}` should have shape (batch_size, {expected_shape_str}), but the current shape" + f" {latents_shape} has {latents_shape[i]} != {expected_shape[i - 1]} at dimension {i}." + ) + + def check_inputs( + self, + mode, + prompt, + image, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + latents=None, + prompt_latents=None, + vae_latents=None, + clip_latents=None, + ): + # Check inputs before running the generative process. + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if mode == "text2img": + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if mode == "img2text": + if image is None: + raise ValueError("`img2text` mode requires an image to be provided.") + + # Check provided latents + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + full_latents_available = latents is not None + prompt_latents_available = prompt_latents is not None + vae_latents_available = vae_latents is not None + clip_latents_available = clip_latents is not None + + if full_latents_available: + individual_latents_available = ( + prompt_latents is not None or vae_latents is not None or clip_latents is not None + ) + if individual_latents_available: + logger.warning( + "You have supplied both `latents` and at least one of `prompt_latents`, `vae_latents`, and" + " `clip_latents`. The value of `latents` will override the value of any individually supplied latents." + ) + # Check shape of full latents + img_vae_dim = self.num_channels_latents * latent_height * latent_width + text_dim = self.text_encoder_seq_len * self.text_encoder_hidden_size + latents_dim = img_vae_dim + self.image_encoder_projection_dim + text_dim + latents_expected_shape = (latents_dim,) + self.check_latents_shape("latents", latents, latents_expected_shape) + + # Check individual latent shapes, if present + if prompt_latents_available: + prompt_latents_expected_shape = (self.text_encoder_seq_len, self.text_encoder_hidden_size) + self.check_latents_shape("prompt_latents", prompt_latents, prompt_latents_expected_shape) + + if vae_latents_available: + vae_latents_expected_shape = (self.num_channels_latents, latent_height, latent_width) + self.check_latents_shape("vae_latents", vae_latents, vae_latents_expected_shape) + + if clip_latents_available: + clip_latents_expected_shape = (1, self.image_encoder_projection_dim) + self.check_latents_shape("clip_latents", clip_latents, clip_latents_expected_shape) + + if mode in ["text2img", "img"] and vae_latents_available and clip_latents_available: + if vae_latents.shape[0] != clip_latents.shape[0]: + raise ValueError( + f"Both `vae_latents` and `clip_latents` are supplied, but their batch dimensions are not equal:" + f" {vae_latents.shape[0]} != {clip_latents.shape[0]}." + ) + + if mode == "joint" and prompt_latents_available and vae_latents_available and clip_latents_available: + if prompt_latents.shape[0] != vae_latents.shape[0] or prompt_latents.shape[0] != clip_latents.shape[0]: + raise ValueError( + f"All of `prompt_latents`, `vae_latents`, and `clip_latents` are supplied, but their batch" + f" dimensions are not equal: {prompt_latents.shape[0]} != {vae_latents.shape[0]}" + f" != {clip_latents.shape[0]}." + ) + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str] | None = None, + image: torch.Tensor | PIL.Image.Image | None = None, + height: int | None = None, + width: int | None = None, + data_type: int | None = 1, + num_inference_steps: int = 50, + guidance_scale: float = 8.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + num_prompts_per_image: int | None = 1, + eta: float = 0.0, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_latents: torch.Tensor | None = None, + vae_latents: torch.Tensor | None = None, + clip_latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback: Callable[[int, int, torch.Tensor], None] | None = None, + callback_steps: int = 1, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + Required for text-conditioned image generation (`text2img`) mode. + image (`torch.Tensor` or `PIL.Image.Image`, *optional*): + `Image` or tensor representing an image batch. Required for image-conditioned text generation + (`img2text`) mode. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + data_type (`int`, *optional*, defaults to 1): + The data type (either 0 or 1). Only used if you are loading a checkpoint which supports a data type + embedding; this is added for compatibility with the + [UniDiffuser-v1](https://huggingface.co/thu-ml/unidiffuser-v1) checkpoint. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 8.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). Used in + text-conditioned image generation (`text2img`) mode. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. Used in `text2img` (text-conditioned image generation) and + `img` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are + supplied, `min(num_images_per_prompt, num_prompts_per_image)` samples are generated. + num_prompts_per_image (`int`, *optional*, defaults to 1): + The number of prompts to generate per image. Used in `img2text` (image-conditioned text generation) and + `text` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are + supplied, `min(num_images_per_prompt, num_prompts_per_image)` samples are generated. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only + applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for joint + image-text generation. Can be used to tweak the same generation with different prompts. If not + provided, a latents tensor is generated by sampling using the supplied random `generator`. This assumes + a full set of VAE, CLIP, and text latents, if supplied, overrides the value of `prompt_latents`, + `vae_latents`, and `clip_latents`. + prompt_latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for text + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + vae_latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + clip_latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. Used in text-conditioned + image generation (`text2img`) mode. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are be generated from the `negative_prompt` input argument. Used + in text-conditioned image generation (`text2img`) mode. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImageTextPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + + Returns: + [`~pipelines.unidiffuser.ImageTextPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.unidiffuser.ImageTextPipelineOutput`] is returned, otherwise a + `tuple` is returned where the first element is a list with the generated images and the second element + is a list of generated texts. + """ + + # 0. Default height and width to unet + height = height or self.unet_resolution * self.vae_scale_factor + width = width or self.unet_resolution * self.vae_scale_factor + + # 1. Check inputs + # Recalculate mode for each call to the pipeline. + mode = self._infer_mode(prompt, prompt_embeds, image, latents, prompt_latents, vae_latents, clip_latents) + self.check_inputs( + mode, + prompt, + image, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + latents, + prompt_latents, + vae_latents, + clip_latents, + ) + + # 2. Define call parameters + batch_size, multiplier = self._infer_batch_size( + mode, + prompt, + prompt_embeds, + image, + num_images_per_prompt, + num_prompts_per_image, + latents, + prompt_latents, + vae_latents, + clip_latents, + ) + device = self._execution_device + reduce_text_emb_dim = self.text_intermediate_dim < self.text_encoder_hidden_size or self.mode != "text2img" + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + # Note that this differs from the formulation in the unidiffusers paper! + do_classifier_free_guidance = guidance_scale > 1.0 + + # check if scheduler is in sigmas space + # scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") + + # 3. Encode input prompt, if available; otherwise prepare text latents + if latents is not None: + # Overwrite individual latents + vae_latents, clip_latents, prompt_latents = self._split_joint(latents, height, width) + + if mode in ["text2img"]: + # 3.1. Encode input prompt, if available + assert prompt is not None or prompt_embeds is not None + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=multiplier, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # if do_classifier_free_guidance: + # prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + else: + # 3.2. Prepare text latent variables, if input not available + prompt_embeds = self.prepare_text_latents( + batch_size=batch_size, + num_images_per_prompt=multiplier, + seq_len=self.text_encoder_seq_len, + hidden_size=self.text_encoder_hidden_size, + dtype=self.text_encoder.dtype, # Should work with both full precision and mixed precision + device=device, + generator=generator, + latents=prompt_latents, + ) + + if reduce_text_emb_dim: + prompt_embeds = self.text_decoder.encode(prompt_embeds) + + # 4. Encode image, if available; otherwise prepare image latents + if mode in ["img2text"]: + # 4.1. Encode images, if available + assert image is not None, "`img2text` requires a conditioning image" + # Encode image using VAE + image_vae = self.image_processor.preprocess(image) + height, width = image_vae.shape[-2:] + image_vae_latents = self.encode_image_vae_latents( + image=image_vae, + batch_size=batch_size, + num_prompts_per_image=multiplier, + dtype=prompt_embeds.dtype, + device=device, + do_classifier_free_guidance=False, # Copied from InstructPix2Pix, don't use their version of CFG + generator=generator, + ) + + # Encode image using CLIP + image_clip_latents = self.encode_image_clip_latents( + image=image, + batch_size=batch_size, + num_prompts_per_image=multiplier, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + ) + # (batch_size, clip_hidden_size) => (batch_size, 1, clip_hidden_size) + image_clip_latents = image_clip_latents.unsqueeze(1) + else: + # 4.2. Prepare image latent variables, if input not available + # Prepare image VAE latents in latent space + image_vae_latents = self.prepare_image_vae_latents( + batch_size=batch_size, + num_prompts_per_image=multiplier, + num_channels_latents=self.num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=vae_latents, + ) + + # Prepare image CLIP latents + image_clip_latents = self.prepare_image_clip_latents( + batch_size=batch_size, + num_prompts_per_image=multiplier, + clip_img_dim=self.image_encoder_projection_dim, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=clip_latents, + ) + + # 5. Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + # max_timestep = timesteps[0] + max_timestep = self.scheduler.config.num_train_timesteps + + # 6. Prepare latent variables + if mode == "joint": + latents = self._combine_joint(image_vae_latents, image_clip_latents, prompt_embeds) + elif mode in ["text2img", "img"]: + latents = self._combine(image_vae_latents, image_clip_latents) + elif mode in ["img2text", "text"]: + latents = prompt_embeds + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + logger.debug(f"Scheduler extra step kwargs: {extra_step_kwargs}") + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # predict the noise residual + # Also applies classifier-free guidance as described in the UniDiffuser paper + noise_pred = self._get_noise_pred( + mode, + latents, + t, + prompt_embeds, + image_vae_latents, + image_clip_latents, + max_timestep, + data_type, + guidance_scale, + generator, + device, + height, + width, + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # 9. Post-processing + image = None + text = None + if mode == "joint": + image_vae_latents, image_clip_latents, text_latents = self._split_joint(latents, height, width) + + if not output_type == "latent": + # Map latent VAE image back to pixel space + image = self.vae.decode(image_vae_latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = image_vae_latents + + text = self.decode_text_latents(text_latents, device) + elif mode in ["text2img", "img"]: + image_vae_latents, image_clip_latents = self._split(latents, height, width) + + if not output_type == "latent": + # Map latent VAE image back to pixel space + image = self.vae.decode(image_vae_latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = image_vae_latents + elif mode in ["img2text", "text"]: + text_latents = latents + text = self.decode_text_latents(text_latents, device) + + self.maybe_free_model_hooks() + + # 10. Postprocess the image, if necessary + if image is not None: + do_denormalize = [True] * image.shape[0] + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, text) + + return ImageTextPipelineOutput(images=image, text=text) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/visualcloze/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/visualcloze/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ab765a1bbad9ae0b8c7f2e7dec5019e3a5e4ea01 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/visualcloze/__init__.py @@ -0,0 +1,52 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_visualcloze_combined"] = ["VisualClozePipeline"] + _import_structure["pipeline_visualcloze_generation"] = ["VisualClozeGenerationPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_visualcloze_combined import VisualClozePipeline + from .pipeline_visualcloze_generation import VisualClozeGenerationPipeline + + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py new file mode 100644 index 0000000000000000000000000000000000000000..f640fddc2bc5cb12b545b0754b18ead903bc963e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py @@ -0,0 +1,440 @@ +# Copyright 2025 VisualCloze team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable + +import torch +from PIL import Image +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ..flux.pipeline_flux_fill import FluxFillPipeline as VisualClozeUpsamplingPipeline +from ..flux.pipeline_output import FluxPipelineOutput +from ..pipeline_utils import DiffusionPipeline +from .pipeline_visualcloze_generation import VisualClozeGenerationPipeline + + +if is_torch_xla_available(): + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import VisualClozePipeline + >>> from diffusers.utils import load_image + + >>> image_paths = [ + ... # in-context examples + ... [ + ... load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg" + ... ), + ... load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg" + ... ), + ... ], + ... # query with the target image + ... [ + ... load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg" + ... ), + ... None, # No image needed for the target image + ... ], + ... ] + >>> task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding." + >>> content_prompt = "Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape. The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible. Its plumage is a mix of dark brown and golden hues, with intricate feather details. The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere. The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field, soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background, tranquil, majestic, wildlife photography." + >>> pipe = VisualClozePipeline.from_pretrained( + ... "VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> image = pipe( + ... task_prompt=task_prompt, + ... content_prompt=content_prompt, + ... image=image_paths, + ... upsampling_width=1344, + ... upsampling_height=768, + ... upsampling_strength=0.4, + ... guidance_scale=30, + ... num_inference_steps=30, + ... max_sequence_length=512, + ... generator=torch.Generator("cpu").manual_seed(0), + ... ).images[0][0] + >>> image.save("visualcloze.png") + ``` +""" + + +class VisualClozePipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): + r""" + The VisualCloze pipeline for image generation with visual context. Reference: + https://github.com/lzyhha/VisualCloze/tree/main. This pipeline is designed to generate images based on visual + in-context examples. + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + resolution (`int`, *optional*, defaults to 384): + The resolution of each image when concatenating images from the query and in-context examples. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + resolution: int = 384, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + + self.generation_pipe = VisualClozeGenerationPipeline( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + resolution=resolution, + ) + self.upsampling_pipe = VisualClozeUpsamplingPipeline( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + + def check_inputs( + self, + image, + task_prompt, + content_prompt, + upsampling_height, + upsampling_width, + strength, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if upsampling_height is not None and upsampling_height % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`upsampling_height`has to be divisible by {self.vae_scale_factor * 2} but are {upsampling_height}. Dimensions will be resized accordingly" + ) + if upsampling_width is not None and upsampling_width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`upsampling_width` have to be divisible by {self.vae_scale_factor * 2} but are {upsampling_width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + # Validate prompt inputs + if (task_prompt is not None or content_prompt is not None) and prompt_embeds is not None: + raise ValueError("Cannot provide both text `task_prompt` + `content_prompt` and `prompt_embeds`. ") + + if task_prompt is None and content_prompt is None and prompt_embeds is None: + raise ValueError("Must provide either `task_prompt` + `content_prompt` or pre-computed `prompt_embeds`. ") + + # Validate prompt types and consistency + if task_prompt is None: + raise ValueError("`task_prompt` is missing.") + + if task_prompt is not None and not isinstance(task_prompt, (str, list)): + raise ValueError(f"`task_prompt` must be str or list, got {type(task_prompt)}") + + if content_prompt is not None and not isinstance(content_prompt, (str, list)): + raise ValueError(f"`content_prompt` must be str or list, got {type(content_prompt)}") + + if isinstance(task_prompt, list) or isinstance(content_prompt, list): + if not isinstance(task_prompt, list) or not isinstance(content_prompt, list): + raise ValueError( + f"`task_prompt` and `content_prompt` must both be lists, or both be of type str or None, " + f"got {type(task_prompt)} and {type(content_prompt)}" + ) + if len(content_prompt) != len(task_prompt): + raise ValueError("`task_prompt` and `content_prompt` must have the same length whe they are lists.") + + for sample in image: + if not isinstance(sample, list) or not isinstance(sample[0], list): + raise ValueError("Each sample in the batch must have a 2D list of images.") + if len({len(row) for row in sample}) != 1: + raise ValueError("Each in-context example and query should contain the same number of images.") + if not any(img is None for img in sample[-1]): + raise ValueError("There are no targets in the query, which should be represented as None.") + for row in sample[:-1]: + if any(img is None for img in row): + raise ValueError("Images are missing in in-context examples.") + + # Validate embeddings + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + # Validate sequence length + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"max_sequence_length cannot exceed 512, got {max_sequence_length}") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + task_prompt: str | list[str] = None, + content_prompt: str | list[str] = None, + image: torch.FloatTensor | None = None, + upsampling_height: int | None = None, + upsampling_width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 30.0, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + upsampling_strength: float = 1.0, + ): + r""" + Function invoked when calling the VisualCloze pipeline for generation. + + Args: + task_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to define the task intention. + content_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to define the content or caption of the target image to be generated. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. + upsampling_height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image (i.e., output image) after upsampling via SDEdit. By + default, the image is upsampled by a factor of three, and the base resolution is determined by the + resolution parameter of the pipeline. When only one of `upsampling_height` or `upsampling_width` is + specified, the other will be automatically set based on the aspect ratio. + upsampling_width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image (i.e., output image) after upsampling via SDEdit. By + default, the image is upsampled by a factor of three, and the base resolution is determined by the + resolution parameter of the pipeline. When only one of `upsampling_height` or `upsampling_width` is + specified, the other will be automatically set based on the aspect ratio. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 30.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + upsampling_strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image` when upsampling the results. Must be between 0 and + 1. The generated image is used as a starting point and more noise is added the higher the + `upsampling_strength`. The number of denoising steps depends on the amount of noise initially added. + When `upsampling_strength` is 1, added noise is maximum and the denoising process runs for the full + number of iterations specified in `num_inference_steps`. A value of 0 skips the upsampling step and + output the results at the resolution of `self.resolution`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + generation_output = self.generation_pipe( + task_prompt=task_prompt, + content_prompt=content_prompt, + image=image, + num_inference_steps=num_inference_steps, + sigmas=sigmas, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + joint_attention_kwargs=joint_attention_kwargs, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + output_type=output_type if upsampling_strength == 0 else "pil", + ) + if upsampling_strength == 0: + if not return_dict: + return (generation_output,) + + return FluxPipelineOutput(images=generation_output) + + # Upsampling the generated images + # 1. Prepare the input images and prompts + if not isinstance(content_prompt, (list)): + content_prompt = [content_prompt] + n_target_per_sample = [] + upsampling_image = [] + upsampling_mask = [] + upsampling_prompt = [] + upsampling_generator = generator if isinstance(generator, (torch.Generator,)) else [] + for i in range(len(generation_output.images)): + n_target_per_sample.append(len(generation_output.images[i])) + for image in generation_output.images[i]: + upsampling_image.append(image) + upsampling_mask.append(Image.new("RGB", image.size, (255, 255, 255))) + upsampling_prompt.append( + content_prompt[i % len(content_prompt)] if content_prompt[i % len(content_prompt)] else "" + ) + if not isinstance(generator, (torch.Generator,)): + upsampling_generator.append(generator[i % len(content_prompt)]) + + # 2. Apply the denosing loop + upsampling_output = self.upsampling_pipe( + prompt=upsampling_prompt, + image=upsampling_image, + mask_image=upsampling_mask, + height=upsampling_height, + width=upsampling_width, + strength=upsampling_strength, + num_inference_steps=num_inference_steps, + sigmas=sigmas, + guidance_scale=guidance_scale, + generator=upsampling_generator, + output_type=output_type, + joint_attention_kwargs=joint_attention_kwargs, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + image = upsampling_output.images + + output = [] + if output_type == "pil": + # Each sample in the batch may have multiple output images. When returning as PIL images, + # these images cannot be concatenated. Therefore, for each sample, + # a list is used to represent all the output images. + output = [] + start = 0 + for n in n_target_per_sample: + output.append(image[start : start + n]) + start += n + else: + output = image + + if not return_dict: + return (output,) + + return FluxPipelineOutput(images=output) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py new file mode 100644 index 0000000000000000000000000000000000000000..dd5d0603d6d063e0199b07fa8f8118c6e5138113 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py @@ -0,0 +1,977 @@ +# Copyright 2025 VisualCloze team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import FluxTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..flux.pipeline_flux_fill import calculate_shift, retrieve_latents, retrieve_timesteps +from ..flux.pipeline_output import FluxPipelineOutput +from ..pipeline_utils import DiffusionPipeline +from .visualcloze_utils import VisualClozeProcessor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import VisualClozeGenerationPipeline, FluxFillPipeline as VisualClozeUpsamplingPipeline + >>> from diffusers.utils import load_image + >>> from PIL import Image + + >>> image_paths = [ + ... # in-context examples + ... [ + ... load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg" + ... ), + ... load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg" + ... ), + ... ], + ... # query with the target image + ... [ + ... load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg" + ... ), + ... None, # No image needed for the target image + ... ], + ... ] + >>> task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding." + >>> content_prompt = "Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape. The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible. Its plumage is a mix of dark brown and golden hues, with intricate feather details. The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere. The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field, soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background, tranquil, majestic, wildlife photography." + >>> pipe = VisualClozeGenerationPipeline.from_pretrained( + ... "VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> image = pipe( + ... task_prompt=task_prompt, + ... content_prompt=content_prompt, + ... image=image_paths, + ... guidance_scale=30, + ... num_inference_steps=30, + ... max_sequence_length=512, + ... generator=torch.Generator("cpu").manual_seed(0), + ... ).images[0][0] + + >>> # optional, upsampling the generated image + >>> pipe_upsample = VisualClozeUpsamplingPipeline.from_pipe(pipe) + >>> pipe_upsample.to("cuda") + + >>> mask_image = Image.new("RGB", image.size, (255, 255, 255)) + + >>> image = pipe_upsample( + ... image=image, + ... mask_image=mask_image, + ... prompt=content_prompt, + ... width=1344, + ... height=768, + ... strength=0.4, + ... guidance_scale=30, + ... num_inference_steps=30, + ... max_sequence_length=512, + ... generator=torch.Generator("cpu").manual_seed(0), + ... ).images[0] + + >>> image.save("visualcloze.png") + ``` +""" + + +class VisualClozeGenerationPipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): + r""" + The VisualCloze pipeline for image generation with visual context. Reference: + https://github.com/lzyhha/VisualCloze/tree/main This pipeline is designed to generate images based on visual + in-context examples. + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + resolution (`int`, *optional*, defaults to 384): + The resolution of each image when concatenating images from the query and in-context examples. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + resolution: int = 384, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + ) + self.resolution = resolution + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + self.image_processor = VisualClozeProcessor( + vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels, resolution=resolution + ) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds + def _get_clip_prompt_embeds( + self, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + # Modified from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt + def encode_prompt( + self, + layout_prompt: str | list[str], + task_prompt: str | list[str], + content_prompt: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 512, + lora_scale: float | None = None, + ): + r""" + + Args: + layout_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to define the number of in-context examples and the number of images involved in + the task. + task_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to define the task intention. + content_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to define the content or caption of the target image to be generated. + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if self.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder_2, lora_scale) + + if isinstance(layout_prompt, str): + layout_prompt = [layout_prompt] + task_prompt = [task_prompt] + content_prompt = [content_prompt] + + def _preprocess(prompt, content=False): + if prompt is not None: + return f"The last image of the last row depicts: {prompt}" if content else prompt + else: + return "" + + prompt = [ + f"{_preprocess(layout_prompt[i])} {_preprocess(task_prompt[i])} {_preprocess(content_prompt[i], content=True)}".strip() + for i in range(len(layout_prompt)) + ] + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + image, + task_prompt, + content_prompt, + prompt_embeds=None, + pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + # Validate prompt inputs + if (task_prompt is not None or content_prompt is not None) and prompt_embeds is not None: + raise ValueError("Cannot provide both text `task_prompt` + `content_prompt` and `prompt_embeds`. ") + + if task_prompt is None and content_prompt is None and prompt_embeds is None: + raise ValueError("Must provide either `task_prompt` + `content_prompt` or pre-computed `prompt_embeds`. ") + + # Validate prompt types and consistency + if task_prompt is None: + raise ValueError("`task_prompt` is missing.") + + if task_prompt is not None and not isinstance(task_prompt, (str, list)): + raise ValueError(f"`task_prompt` must be str or list, got {type(task_prompt)}") + + if content_prompt is not None and not isinstance(content_prompt, (str, list)): + raise ValueError(f"`content_prompt` must be str or list, got {type(content_prompt)}") + + if isinstance(task_prompt, list) or isinstance(content_prompt, list): + if not isinstance(task_prompt, list) or not isinstance(content_prompt, list): + raise ValueError( + f"`task_prompt` and `content_prompt` must both be lists, or both be of type str or None, " + f"got {type(task_prompt)} and {type(content_prompt)}" + ) + if len(content_prompt) != len(task_prompt): + raise ValueError("`task_prompt` and `content_prompt` must have the same length whe they are lists.") + + for sample in image: + if not isinstance(sample, list) or not isinstance(sample[0], list): + raise ValueError("Each sample in the batch must have a 2D list of images.") + if len({len(row) for row in sample}) != 1: + raise ValueError("Each in-context example and query should contain the same number of images.") + if not any(img is None for img in sample[-1]): + raise ValueError("There are no targets in the query, which should be represented as None.") + for row in sample[:-1]: + if any(img is None for img in row): + raise ValueError("Images are missing in in-context examples.") + + # Validate embeddings + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + # Validate sequence length + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"max_sequence_length cannot exceed 512, got {max_sequence_length}") + + @staticmethod + def _prepare_latent_image_ids(image, vae_scale_factor, device, dtype): + latent_image_ids = [] + + for idx, img in enumerate(image, start=1): + img = img.squeeze(0) + channels, height, width = img.shape + + num_patches_h = height // vae_scale_factor // 2 + num_patches_w = width // vae_scale_factor // 2 + + patch_ids = torch.zeros(num_patches_h, num_patches_w, 3, device=device, dtype=dtype) + patch_ids[..., 0] = idx + patch_ids[..., 1] = torch.arange(num_patches_h, device=device, dtype=dtype)[:, None] + patch_ids[..., 2] = torch.arange(num_patches_w, device=device, dtype=dtype)[None, :] + + patch_ids = patch_ids.reshape(-1, 3) + latent_image_ids.append(patch_ids) + + return torch.cat(latent_image_ids, dim=0) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, sizes, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + start = 0 + unpacked_latents = [] + for i in range(len(sizes)): + cur_size = sizes[i] + height = cur_size[0][0] // vae_scale_factor + width = sum([size[1] for size in cur_size]) // vae_scale_factor + + end = start + (height * width) // 4 + + cur_latents = latents[:, start:end] + cur_latents = cur_latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + cur_latents = cur_latents.permute(0, 3, 1, 4, 2, 5) + cur_latents = cur_latents.reshape(batch_size, channels // (2 * 2), height, width) + + unpacked_latents.append(cur_latents) + + start = end + + return unpacked_latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + def _prepare_latents(self, image, mask, gen, vae_scale_factor, device, dtype): + """Helper function to prepare latents for a single batch.""" + # Concatenate images and masks along width dimension + image = [torch.cat(img, dim=3).to(device=device, dtype=dtype) for img in image] + mask = [torch.cat(m, dim=3).to(device=device, dtype=dtype) for m in mask] + + # Generate latent image IDs + latent_image_ids = self._prepare_latent_image_ids(image, vae_scale_factor, device, dtype) + + # For initial encoding, use actual images + image_latent = [self._encode_vae_image(img, gen) for img in image] + masked_image_latent = [img.clone() for img in image_latent] + + for i in range(len(image_latent)): + # Rearrange latents and masks for patch processing + num_channels_latents, height, width = image_latent[i].shape[1:] + image_latent[i] = self._pack_latents(image_latent[i], 1, num_channels_latents, height, width) + masked_image_latent[i] = self._pack_latents(masked_image_latent[i], 1, num_channels_latents, height, width) + + # Rearrange masks for patch processing + num_channels_latents, height, width = mask[i].shape[1:] + mask[i] = mask[i].view( + 1, + num_channels_latents, + height // vae_scale_factor, + vae_scale_factor, + width // vae_scale_factor, + vae_scale_factor, + ) + mask[i] = mask[i].permute(0, 1, 3, 5, 2, 4) + mask[i] = mask[i].reshape( + 1, + num_channels_latents * (vae_scale_factor**2), + height // vae_scale_factor, + width // vae_scale_factor, + ) + mask[i] = self._pack_latents( + mask[i], + 1, + num_channels_latents * (vae_scale_factor**2), + height // vae_scale_factor, + width // vae_scale_factor, + ) + + # Concatenate along batch dimension + image_latent = torch.cat(image_latent, dim=1) + masked_image_latent = torch.cat(masked_image_latent, dim=1) + mask = torch.cat(mask, dim=1) + + return image_latent, masked_image_latent, mask, latent_image_ids + + def prepare_latents( + self, + input_image, + input_mask, + timestep, + batch_size, + dtype, + device, + generator, + vae_scale_factor, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # Process each batch + masked_image_latents = [] + image_latents = [] + masks = [] + latent_image_ids = [] + + for i in range(len(input_image)): + _image_latent, _masked_image_latent, _mask, _latent_image_ids = self._prepare_latents( + input_image[i], + input_mask[i], + generator if isinstance(generator, torch.Generator) else generator[i], + vae_scale_factor, + device, + dtype, + ) + masked_image_latents.append(_masked_image_latent) + image_latents.append(_image_latent) + masks.append(_mask) + latent_image_ids.append(_latent_image_ids) + + # Concatenate all batches + masked_image_latents = torch.cat(masked_image_latents, dim=0) + image_latents = torch.cat(image_latents, dim=0) + masks = torch.cat(masks, dim=0) + + # Handle batch size expansion + if batch_size > masked_image_latents.shape[0]: + if batch_size % masked_image_latents.shape[0] == 0: + # Expand batches by repeating + additional_image_per_prompt = batch_size // masked_image_latents.shape[0] + masked_image_latents = torch.cat([masked_image_latents] * additional_image_per_prompt, dim=0) + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + masks = torch.cat([masks] * additional_image_per_prompt, dim=0) + else: + raise ValueError( + f"Cannot expand batch size from {masked_image_latents.shape[0]} to {batch_size}. " + "Batch sizes must be multiples of each other." + ) + + # Add noise to latents + noises = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noises).to(dtype=dtype) + + # Combine masked latents with masks + masked_image_latents = torch.cat((masked_image_latents, masks), dim=-1).to(dtype=dtype) + + return latents, masked_image_latents, latent_image_ids[0] + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + task_prompt: str | list[str] = None, + content_prompt: str | list[str] = None, + image: torch.FloatTensor | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 30.0, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the VisualCloze pipeline for generation. + + Args: + task_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to define the task intention. + content_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to define the content or caption of the target image to be generated. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 30.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + image, + task_prompt, + content_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + processor_output = self.image_processor.preprocess( + task_prompt, content_prompt, image, vae_scale_factor=self.vae_scale_factor + ) + + # 2. Define call parameters + if processor_output["task_prompt"] is not None and isinstance(processor_output["task_prompt"], str): + batch_size = 1 + elif processor_output["task_prompt"] is not None and isinstance(processor_output["task_prompt"], list): + batch_size = len(processor_output["task_prompt"]) + + device = self._execution_device + + # 3. Prepare prompt embeddings + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt( + layout_prompt=processor_output["layout_prompt"], + task_prompt=processor_output["task_prompt"], + content_prompt=processor_output["content_prompt"], + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare timesteps + # Calculate sequence length and shift factor + image_seq_len = sum( + (size[0] // self.vae_scale_factor // 2) * (size[1] // self.vae_scale_factor // 2) + for sample in processor_output["image_size"][0] + for size in sample + ) + + # Calculate noise schedule parameters + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + + # Get timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, 1.0, device) + + # 5. Prepare latent variables + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + latents, masked_image_latents, latent_image_ids = self.prepare_latents( + processor_output["init_image"], + processor_output["mask"], + latent_timestep, + batch_size * num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + vae_scale_factor=self.vae_scale_factor, + ) + + # Calculate warmup steps + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # Prepare guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # Broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + latent_model_input = torch.cat((latents, masked_image_latents), dim=2) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + # Compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # Some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # Call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # XLA optimization + if XLA_AVAILABLE: + xm.mark_step() + + # 7. Post-process the image + # Crop the target image + # Since the generated image is a concatenation of the conditional and target regions, + # we need to extract only the target regions based on their positions + image = [] + if output_type == "latent": + image = latents + else: + for b in range(len(latents)): + cur_image_size = processor_output["image_size"][b % batch_size] + cur_target_position = processor_output["target_position"][b % batch_size] + cur_latent = self._unpack_latents(latents[b].unsqueeze(0), cur_image_size, self.vae_scale_factor)[-1] + cur_latent = (cur_latent / self.vae.config.scaling_factor) + self.vae.config.shift_factor + cur_image = self.vae.decode(cur_latent, return_dict=False)[0] + cur_image = self.image_processor.postprocess(cur_image, output_type=output_type)[0] + + start = 0 + cropped = [] + for i, size in enumerate(cur_image_size[-1]): + if cur_target_position[i]: + if output_type == "pil": + cropped.append(cur_image.crop((start, 0, start + size[1], size[0]))) + else: + cropped.append(cur_image[0 : size[0], start : start + size[1]]) + start += size[1] + image.append(cropped) + if output_type != "pil": + image = np.concatenate([arr[None] for sub_image in image for arr in sub_image], axis=0) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/visualcloze/visualcloze_utils.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/visualcloze/visualcloze_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1a96ff8d2d1357440d49a15f88832802581b7b8d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/visualcloze/visualcloze_utils.py @@ -0,0 +1,250 @@ +# Copyright 2025 VisualCloze team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from PIL import Image + +from ...image_processor import VaeImageProcessor + + +class VisualClozeProcessor(VaeImageProcessor): + """ + Image processor for the VisualCloze pipeline. + + This processor handles the preprocessing of images for visual cloze tasks, including resizing, normalization, and + mask generation. + + Args: + resolution (int, optional): + Target resolution for processing images. Each image will be resized to this resolution before being + concatenated to avoid the out-of-memory error. Defaults to 384. + *args: Additional arguments passed to [~image_processor.VaeImageProcessor] + **kwargs: Additional keyword arguments passed to [~image_processor.VaeImageProcessor] + """ + + def __init__(self, *args, resolution: int = 384, **kwargs): + super().__init__(*args, **kwargs) + self.resolution = resolution + + def preprocess_image( + self, input_images: list[list[Image.Image | None]], vae_scale_factor: int + ) -> tuple[list[list[torch.Tensor]], list[list[list[int]]], list[int]]: + """ + Preprocesses input images for the VisualCloze pipeline. + + This function handles the preprocessing of input images by: + 1. Resizing and cropping images to maintain consistent dimensions + 2. Converting images to the Tensor format for the VAE + 3. Normalizing pixel values + 4. Tracking image sizes and positions of target images + + Args: + input_images (list[list[Image.Image | None]]): + A nested list of PIL Images where: + - Outer list represents different samples, including in-context examples and the query + - Inner list contains images for the task + - In the last row, condition images are provided and the target images are placed as None + vae_scale_factor (int): + The scale factor used by the VAE for resizing images + + Returns: + tuple containing: + - list[list[torch.Tensor]]: Preprocessed images in tensor format + - list[list[list[int]]]: Dimensions of each processed image [height, width] + - list[int]: Target positions indicating which images are to be generated + """ + n_samples, n_task_images = len(input_images), len(input_images[0]) + divisible = 2 * vae_scale_factor + + processed_images: list[list[Image.Image]] = [[] for _ in range(n_samples)] + resize_size: list[tuple[int, int] | None] = [None for _ in range(n_samples)] + target_position: list[int] = [] + + # Process each sample + for i in range(n_samples): + # Determine size from first non-None image + for j in range(n_task_images): + if input_images[i][j] is not None: + aspect_ratio = input_images[i][j].width / input_images[i][j].height + target_area = self.resolution * self.resolution + new_h = int((target_area / aspect_ratio) ** 0.5) + new_w = int(new_h * aspect_ratio) + + new_w = max(new_w // divisible, 1) * divisible + new_h = max(new_h // divisible, 1) * divisible + resize_size[i] = (new_w, new_h) + break + + # Process all images in the sample + for j in range(n_task_images): + if input_images[i][j] is not None: + target = self._resize_and_crop(input_images[i][j], resize_size[i][0], resize_size[i][1]) + processed_images[i].append(target) + if i == n_samples - 1: + target_position.append(0) + else: + blank = Image.new("RGB", resize_size[i] or (self.resolution, self.resolution), (0, 0, 0)) + processed_images[i].append(blank) + if i == n_samples - 1: + target_position.append(1) + + # Ensure consistent width for multiple target images when there are multiple target images + if len(target_position) > 1 and sum(target_position) > 1: + new_w = resize_size[n_samples - 1][0] or 384 + for i in range(len(processed_images)): + for j in range(len(processed_images[i])): + if processed_images[i][j] is not None: + new_h = int(processed_images[i][j].height * (new_w / processed_images[i][j].width)) + new_w = int(new_w / 16) * 16 + new_h = int(new_h / 16) * 16 + processed_images[i][j] = self._resize_and_crop(processed_images[i][j], new_h, new_w) + + # Convert to tensors and normalize + image_sizes = [] + for i in range(len(processed_images)): + image_sizes.append([[img.height, img.width] for img in processed_images[i]]) + for j, image in enumerate(processed_images[i]): + image = self.pil_to_numpy(image) + image = self.numpy_to_pt(image) + image = self.normalize(image) + processed_images[i][j] = image + + return processed_images, image_sizes, target_position + + def preprocess_mask( + self, input_images: list[list[Image.Image]], target_position: list[int] + ) -> list[list[torch.Tensor]]: + """ + Generate masks for the VisualCloze pipeline. + + Args: + input_images (list[list[Image.Image]]): + Processed images from preprocess_image + target_position (list[int]): + Binary list marking the positions of target images (1 for target, 0 for condition) + + Returns: + list[list[torch.Tensor]]: + A nested list of mask tensors (1 for target positions, 0 for condition images) + """ + mask = [] + for i, row in enumerate(input_images): + if i == len(input_images) - 1: # Query row + row_masks = [ + torch.full((1, 1, row[0].shape[2], row[0].shape[3]), fill_value=m) for m in target_position + ] + else: # In-context examples + row_masks = [ + torch.full((1, 1, row[0].shape[2], row[0].shape[3]), fill_value=0) for _ in target_position + ] + mask.append(row_masks) + return mask + + def preprocess_image_upsampling( + self, + input_images: list[list[Image.Image]], + height: int, + width: int, + ) -> tuple[list[list[Image.Image]], list[list[list[int]]]]: + """Process images for the upsampling stage in the VisualCloze pipeline. + + Args: + input_images: Input image to process + height: Target height + width: Target width + + Returns: + tuple of processed image and its size + """ + image = self.resize(input_images[0][0], height, width) + image = self.pil_to_numpy(image) # to np + image = self.numpy_to_pt(image) # to pt + image = self.normalize(image) + + input_images[0][0] = image + image_sizes = [[[height, width]]] + return input_images, image_sizes + + def preprocess_mask_upsampling(self, input_images: list[list[Image.Image]]) -> list[list[torch.Tensor]]: + return [[torch.ones((1, 1, input_images[0][0].shape[2], input_images[0][0].shape[3]))]] + + def get_layout_prompt(self, size: tuple[int, int]) -> str: + layout_instruction = ( + f"A grid layout with {size[0]} rows and {size[1]} columns, displaying {size[0] * size[1]} images arranged side by side.", + ) + return layout_instruction + + def preprocess( + self, + task_prompt: str | list[str], + content_prompt: str | list[str], + input_images: list[list[list[str | None]]] | None = None, + height: int | None = None, + width: int | None = None, + upsampling: bool = False, + vae_scale_factor: int = 16, + ) -> dict: + """Process visual cloze inputs. + + Args: + task_prompt: Task description(s) + content_prompt: Content description(s) + input_images: list of images or None for the target images + height: Optional target height for upsampling stage + width: Optional target width for upsampling stage + upsampling: Whether this is in the upsampling processing stage + + Returns: + dictionary containing processed images, masks, prompts and metadata + """ + if isinstance(task_prompt, str): + task_prompt = [task_prompt] + content_prompt = [content_prompt] + input_images = [input_images] + + output = { + "init_image": [], + "mask": [], + "task_prompt": task_prompt if not upsampling else [None for _ in range(len(task_prompt))], + "content_prompt": content_prompt, + "layout_prompt": [], + "target_position": [], + "image_size": [], + } + for i in range(len(task_prompt)): + if upsampling: + layout_prompt = None + else: + layout_prompt = self.get_layout_prompt((len(input_images[i]), len(input_images[i][0]))) + + if upsampling: + cur_processed_images, cur_image_size = self.preprocess_image_upsampling( + input_images[i], height=height, width=width + ) + cur_mask = self.preprocess_mask_upsampling(cur_processed_images) + else: + cur_processed_images, cur_image_size, cur_target_position = self.preprocess_image( + input_images[i], vae_scale_factor=vae_scale_factor + ) + cur_mask = self.preprocess_mask(cur_processed_images, cur_target_position) + + output["target_position"].append(cur_target_position) + + output["image_size"].append(cur_image_size) + output["init_image"].append(cur_processed_images) + output["mask"].append(cur_mask) + output["layout_prompt"].append(layout_prompt) + + return output diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ad51a52f9242052d1424a360c59cac72f2d8129e --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__init__.py @@ -0,0 +1,54 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_wan"] = ["WanPipeline"] + _import_structure["pipeline_wan_animate"] = ["WanAnimatePipeline"] + _import_structure["pipeline_wan_i2v"] = ["WanImageToVideoPipeline"] + _import_structure["pipeline_wan_vace"] = ["WanVACEPipeline"] + _import_structure["pipeline_wan_video2video"] = ["WanVideoToVideoPipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_wan import WanPipeline + from .pipeline_wan_animate import WanAnimatePipeline + from .pipeline_wan_i2v import WanImageToVideoPipeline + from .pipeline_wan_vace import WanVACEPipeline + from .pipeline_wan_video2video import WanVideoToVideoPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b37fa2546350a53832bbf57bf5a701b00c820443 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__pycache__/image_processor.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__pycache__/image_processor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e80623ad4ee13748350dac3673ddcf1b163fce98 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__pycache__/image_processor.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__pycache__/pipeline_output.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__pycache__/pipeline_output.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf2cc170224f1558183b6568b674e97ee2b1f55d Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__pycache__/pipeline_output.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__pycache__/pipeline_wan.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__pycache__/pipeline_wan.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6ccd9bbe3509ab80a93fd65097da90a8bfb5946 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__pycache__/pipeline_wan.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__pycache__/pipeline_wan_animate.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__pycache__/pipeline_wan_animate.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..969b3c5423034df28b04fc393aa79464041ad7c9 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__pycache__/pipeline_wan_animate.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__pycache__/pipeline_wan_i2v.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__pycache__/pipeline_wan_i2v.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..998f2e0e109a579a4da6d83425041962d07d8b3c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__pycache__/pipeline_wan_i2v.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__pycache__/pipeline_wan_vace.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__pycache__/pipeline_wan_vace.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1478f4684222939843a221ca3577ebddc0f436de Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__pycache__/pipeline_wan_vace.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__pycache__/pipeline_wan_video2video.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__pycache__/pipeline_wan_video2video.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f591004326048dad2dd49f51874bd712e540362c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/__pycache__/pipeline_wan_video2video.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/image_processor.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/image_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..fa18150fcc6e66d4d986474e4dd3f68c13ba494c --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/image_processor.py @@ -0,0 +1,184 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import PIL.Image +import torch + +from ...configuration_utils import register_to_config +from ...image_processor import VaeImageProcessor +from ...utils import PIL_INTERPOLATION + + +class WanAnimateImageProcessor(VaeImageProcessor): + r""" + Image processor to preprocess the reference (character) image for the Wan Animate model. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept + `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. + vae_scale_factor (`int`, *optional*, defaults to `8`): + VAE (spatial) scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of + this factor. + vae_latent_channels (`int`, *optional*, defaults to `16`): + VAE latent channels. + spatial_patch_size (`tuple[int, int]`, *optional*, defaults to `(2, 2)`): + The spatial patch size used by the diffusion transformer. For Wan models, this is typically (2, 2). + resample (`str`, *optional*, defaults to `lanczos`): + Resampling filter to use when resizing the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image to [-1,1]. + do_binarize (`bool`, *optional*, defaults to `False`): + Whether to binarize the image to 0/1. + do_convert_rgb (`bool`, *optional*, defaults to be `False`): + Whether to convert the images to RGB format. + do_convert_grayscale (`bool`, *optional*, defaults to be `False`): + Whether to convert the images to grayscale format. + fill_color (`str` or `float` or `tuple[float, ...]`, *optional*, defaults to `None`): + An optional fill color when `resize_mode` is set to `"fill"`. This will fill the empty space with that + color instead of filling with data from the image. Any valid `color` argument to `PIL.Image.new` is valid; + if `None`, will default to filling with data from `image`. + """ + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 8, + vae_latent_channels: int = 16, + spatial_patch_size: tuple[int, int] = (2, 2), + resample: str = "lanczos", + reducing_gap: int = None, + do_normalize: bool = True, + do_binarize: bool = False, + do_convert_rgb: bool = False, + do_convert_grayscale: bool = False, + fill_color: str | float | tuple[float, ...] | None = 0, + ): + super().__init__() + if do_convert_rgb and do_convert_grayscale: + raise ValueError( + "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`," + " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.", + " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`", + ) + + def _resize_and_fill( + self, + image: PIL.Image.Image, + width: int, + height: int, + ) -> PIL.Image.Image: + r""" + Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center + the image within the dimensions, filling empty with data from image. + + Args: + image (`PIL.Image.Image`): + The image to resize and fill. + width (`int`): + The width to resize the image to. + height (`int`): + The height to resize the image to. + + Returns: + `PIL.Image.Image`: + The resized and filled image. + """ + + ratio = width / height + src_ratio = image.width / image.height + fill_with_image_data = self.config.fill_color is None + fill_color = self.config.fill_color or 0 + + src_w = width if ratio < src_ratio else image.width * height // image.height + src_h = height if ratio >= src_ratio else image.height * width // image.width + + resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample]) + res = PIL.Image.new("RGB", (width, height), color=fill_color) + res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) + + if fill_with_image_data: + if ratio < src_ratio: + fill_height = height // 2 - src_h // 2 + if fill_height > 0: + res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) + res.paste( + resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), + box=(0, fill_height + src_h), + ) + elif ratio > src_ratio: + fill_width = width // 2 - src_w // 2 + if fill_width > 0: + res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) + res.paste( + resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), + box=(fill_width + src_w, 0), + ) + + return res + + def get_default_height_width( + self, + image: PIL.Image.Image | np.ndarray | torch.Tensor, + height: int | None = None, + width: int | None = None, + ) -> tuple[int, int]: + r""" + Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`. + + Args: + image (`PIL.Image.Image | np.ndarray | torch.Tensor`): + The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it + should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch + tensor, it should have shape `[batch, channels, height, width]`. + height (`int | None`, *optional*, defaults to `None`): + The height of the preprocessed image. If `None`, the height of the `image` input will be used. + width (`int | None`, *optional*, defaults to `None`): + The width of the preprocessed image. If `None`, the width of the `image` input will be used. + + Returns: + `tuple[int, int]`: + A tuple containing the height and width, both resized to the nearest integer multiple of + `vae_scale_factor * spatial_patch_size`. + """ + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + else: + height = image.shape[1] + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + else: + width = image.shape[2] + + max_area = width * height + aspect_ratio = height / width + mod_value_h = self.config.vae_scale_factor * self.config.spatial_patch_size[0] + mod_value_w = self.config.vae_scale_factor * self.config.spatial_patch_size[1] + + # Try to preserve the aspect ratio + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value_h * mod_value_h + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value_w * mod_value_w + + return height, width diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..7c7f8a4f0ad9601786439718ffbd8e446c6549fb --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for Wan pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_wan.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_wan.py new file mode 100644 index 0000000000000000000000000000000000000000..d4edff01ad667afaeef9a5e0b6b25d22a7720b25 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_wan.py @@ -0,0 +1,671 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +from typing import Any, Callable + +import regex as re +import torch +from transformers import AutoTokenizer, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import WanLoraLoaderMixin +from ...models import AutoencoderKLWan, WanTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import WanPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers.utils import export_to_video + >>> from diffusers import AutoencoderKLWan, WanPipeline + >>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler + + >>> # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers + >>> model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers" + >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + >>> pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) + >>> flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift) + >>> pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + + >>> output = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=720, + ... width=1280, + ... num_frames=81, + ... guidance_scale=5.0, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=16) + ``` +""" + + +def basic_clean(text): + if is_ftfy_available(): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + transformer ([`WanTransformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + transformer_2 ([`WanTransformer3DModel`], *optional*): + Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables + two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise + stages. If not provided, only `transformer` is used. + boundary_ratio (`float`, *optional*, defaults to `None`): + Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising. + The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided, + `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps < + boundary_timestep. If `None`, only `transformer` is used for the entire denoising process. + """ + + model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _optional_components = ["transformer", "transformer_2"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKLWan, + scheduler: FlowMatchEulerDiscreteScheduler, + transformer: WanTransformer3DModel | None = None, + transformer_2: WanTransformer3DModel | None = None, + boundary_ratio: float | None = None, + expand_timesteps: bool = False, # Wan2.2 ti2v + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + transformer_2=transformer_2, + ) + self.register_to_config(boundary_ratio=boundary_ratio) + self.register_to_config(expand_timesteps=expand_timesteps) + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + guidance_scale_2=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if self.config.boundary_ratio is None and guidance_scale_2 is not None: + raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.") + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + guidance_scale_2: float | None = None, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). + height (`int`, defaults to `480`): + The height in pixels of the generated image. + width (`int`, defaults to `832`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_scale_2 (`float`, *optional*, defaults to `None`): + Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's + `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2` + and the pipeline's `boundary_ratio` are not None. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + + Examples: + + Returns: + [`~WanPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + guidance_scale_2, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + patch_size = ( + self.transformer.config.patch_size + if self.transformer is not None + else self.transformer_2.config.patch_size + ) + h_multiple_of = self.vae_scale_factor_spatial * patch_size[1] + w_multiple_of = self.vae_scale_factor_spatial * patch_size[2] + calc_height = height // h_multiple_of * h_multiple_of + calc_width = width // w_multiple_of * w_multiple_of + if height != calc_height or width != calc_width: + logger.warning( + f"`height` and `width` must be multiples of ({h_multiple_of}, {w_multiple_of}) for proper patchification. " + f"Adjusting ({height}, {width}) -> ({calc_height}, {calc_width})." + ) + height, width = calc_height, calc_width + + if self.config.boundary_ratio is not None and guidance_scale_2 is None: + guidance_scale_2 = guidance_scale + + self._guidance_scale = guidance_scale + self._guidance_scale_2 = guidance_scale_2 + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = ( + self.transformer.config.in_channels + if self.transformer is not None + else self.transformer_2.config.in_channels + ) + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + mask = torch.ones(latents.shape, dtype=torch.float32, device=device) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + if self.config.boundary_ratio is not None: + boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps + else: + boundary_timestep = None + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + if boundary_timestep is None or t >= boundary_timestep: + # wan2.1 or high-noise stage in wan2.2 + current_model = self.transformer + current_guidance_scale = guidance_scale + else: + # low-noise stage in wan2.2 + current_model = self.transformer_2 + current_guidance_scale = guidance_scale_2 + + latent_model_input = latents.to(transformer_dtype) + if self.config.expand_timesteps: + # seq_len: num_latent_frames * latent_height//2 * latent_width//2 + temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten() + # batch_size, seq_len + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + else: + timestep = t.expand(latents.shape[0]) + + with current_model.cache_context("cond"): + noise_pred = current_model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + with current_model.cache_context("uncond"): + noise_uncond = current_model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_wan_animate.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_wan_animate.py new file mode 100644 index 0000000000000000000000000000000000000000..5806032c0142d6d04be3605a08bb665c4724b560 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_wan_animate.py @@ -0,0 +1,1202 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +from copy import deepcopy +from typing import Any, Callable + +import PIL +import regex as re +import torch +import torch.nn.functional as F +from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import WanLoraLoaderMixin +from ...models import AutoencoderKLWan, WanAnimateTransformer3DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .image_processor import WanAnimateImageProcessor +from .pipeline_output import WanPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> import numpy as np + >>> from diffusers import WanAnimatePipeline + >>> from diffusers.utils import export_to_video, load_image, load_video + + >>> model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers" + >>> pipe = WanAnimatePipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) + >>> # Optionally upcast the Wan VAE to FP32 + >>> pipe.vae.to(torch.float32) + >>> pipe.to("cuda") + + >>> # Load the reference character image + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" + ... ) + + >>> # Load pose and face videos (preprocessed from reference video) + >>> # Note: Videos should be preprocessed to extract pose keypoints and face features + >>> # Refer to the Wan-Animate preprocessing documentation for details + >>> pose_video = load_video("path/to/pose_video.mp4") + >>> face_video = load_video("path/to/face_video.mp4") + + >>> # CFG is generally not used for Wan Animate + >>> prompt = ( + ... "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in " + ... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." + ... ) + + >>> # Animation mode: Animate the character with the motion from pose/face videos + >>> output = pipe( + ... image=image, + ... pose_video=pose_video, + ... face_video=face_video, + ... prompt=prompt, + ... height=height, + ... width=width, + ... segment_frame_length=77, # Frame length of each inference segment + ... guidance_scale=1.0, + ... num_inference_steps=20, + ... mode="animate", + ... ).frames[0] + >>> export_to_video(output, "output_animation.mp4", fps=30) + + >>> # Replacement mode: Replace a character in the background video + >>> # Requires additional background_video and mask_video inputs + >>> background_video = load_video("path/to/background_video.mp4") + >>> mask_video = load_video("path/to/mask_video.mp4") # Black areas preserved, white areas generated + >>> output = pipe( + ... image=image, + ... pose_video=pose_video, + ... face_video=face_video, + ... background_video=background_video, + ... mask_video=mask_video, + ... prompt=prompt, + ... height=height, + ... width=width, + ... segment_frame_length=77, # Frame length of each inference segment + ... guidance_scale=1.0, + ... num_inference_steps=20, + ... mode="replace", + ... ).frames[0] + >>> export_to_video(output, "output_replacement.mp4", fps=30) + ``` +""" + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class WanAnimatePipeline(DiffusionPipeline, WanLoraLoaderMixin): + r""" + Pipeline for unified character animation and replacement using Wan-Animate. + + WanAnimatePipeline takes a character image, pose video, and face video as input, and generates a video in two + modes: + + 1. **Animation mode**: The model generates a video of the character image that mimics the human motion in the input + pose and face videos. The character is animated based on the provided motion controls, creating a new animated + video of the character. + + 2. **Replacement mode**: The model replaces a character in a background video with the provided character image, + using the pose and face videos for motion control. This mode requires additional `background_video` and + `mask_video` inputs. The mask video should have black regions where the original content should be preserved and + white regions where the new character should be generated. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.WanLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + image_encoder ([`CLIPVisionModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically + the + [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large) + variant. + transformer ([`WanAnimateTransformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + image_processor ([`CLIPImageProcessor`]): + Image processor for preprocessing images before encoding. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKLWan, + scheduler: UniPCMultistepScheduler, + image_processor: CLIPImageProcessor, + image_encoder: CLIPVisionModel, + transformer: WanAnimateTransformer3DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + image_encoder=image_encoder, + transformer=transformer, + scheduler=scheduler, + image_processor=image_processor, + ) + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.video_processor_for_mask = VideoProcessor( + vae_scale_factor=self.vae_scale_factor_spatial, do_normalize=False, do_convert_grayscale=True + ) + # In case self.transformer is None (e.g. for some pipeline tests) + spatial_patch_size = self.transformer.config.patch_size[-2:] if self.transformer is not None else (2, 2) + self.vae_image_processor = WanAnimateImageProcessor( + vae_scale_factor=self.vae_scale_factor_spatial, + spatial_patch_size=spatial_patch_size, + resample="bilinear", + fill_color=0, + ) + self.image_processor = image_processor + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_image + def encode_image( + self, + image: PipelineImageInput, + device: torch.device | None = None, + ): + device = device or self._execution_device + image = self.image_processor(images=image, return_tensors="pt").to(device) + image_embeds = self.image_encoder(**image, output_hidden_states=True) + return image_embeds.hidden_states[-2] + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + image, + pose_video, + face_video, + background_video, + mask_video, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + image_embeds=None, + callback_on_step_end_tensor_inputs=None, + mode=None, + prev_segment_conditioning_frames=None, + ): + if image is not None and image_embeds is not None: + raise ValueError( + f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to" + " only forward one of the two." + ) + if image is None and image_embeds is None: + raise ValueError( + "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined." + ) + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") + if pose_video is None: + raise ValueError("Provide `pose_video`. Cannot leave `pose_video` undefined.") + if face_video is None: + raise ValueError("Provide `face_video`. Cannot leave `face_video` undefined.") + if not isinstance(pose_video, list) or not isinstance(face_video, list): + raise ValueError("`pose_video` and `face_video` must be lists of PIL images.") + if len(pose_video) == 0 or len(face_video) == 0: + raise ValueError("`pose_video` and `face_video` must contain at least one frame.") + if mode == "replace" and (background_video is None or mask_video is None): + raise ValueError( + "Provide `background_video` and `mask_video`. Cannot leave both `background_video` and `mask_video`" + " undefined when mode is `replace`." + ) + if mode == "replace" and (not isinstance(background_video, list) or not isinstance(mask_video, list)): + raise ValueError("`background_video` and `mask_video` must be lists of PIL images when mode is `replace`.") + + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found" + f" {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if mode is not None and (not isinstance(mode, str) or mode not in ("animate", "replace")): + raise ValueError( + f"`mode` has to be of type `str` and in ('animate', 'replace') but its type is {type(mode)} and value is {mode}" + ) + + if prev_segment_conditioning_frames is not None and ( + not isinstance(prev_segment_conditioning_frames, int) or prev_segment_conditioning_frames not in (1, 5) + ): + raise ValueError( + f"`prev_segment_conditioning_frames` has to be of type `int` and 1 or 5 but its type is" + f" {type(prev_segment_conditioning_frames)} and value is {prev_segment_conditioning_frames}" + ) + + def get_i2v_mask( + self, + batch_size: int, + latent_t: int, + latent_h: int, + latent_w: int, + mask_len: int = 1, + mask_pixel_values: torch.Tensor | None = None, + dtype: torch.dtype | None = None, + device: str | torch.device = "cuda", + ) -> torch.Tensor: + # mask_pixel_values shape (if supplied): [B, C = 1, T, latent_h, latent_w] + if mask_pixel_values is None: + mask_lat_size = torch.zeros( + batch_size, 1, (latent_t - 1) * 4 + 1, latent_h, latent_w, dtype=dtype, device=device + ) + else: + mask_lat_size = mask_pixel_values.clone().to(device=device, dtype=dtype) + mask_lat_size[:, :, :mask_len] = 1 + first_frame_mask = mask_lat_size[:, :, 0:1] + # Repeat first frame mask self.vae_scale_factor_temporal (= 4) times in the frame dimension + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:]], dim=2) + mask_lat_size = mask_lat_size.view( + batch_size, -1, self.vae_scale_factor_temporal, latent_h, latent_w + ).transpose(1, 2) # [B, C = 1, 4 * T_lat, H_lat, W_lat] --> [B, C = 4, T_lat, H_lat, W_lat] + + return mask_lat_size + + def prepare_reference_image_latents( + self, + image: torch.Tensor, + batch_size: int = 1, + sample_mode: int = "argmax", + generator: torch.Generator | list[torch.Generator] | None = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + ) -> torch.Tensor: + # image shape: (B, C, H, W) or (B, C, T, H, W) + dtype = dtype or self.vae.dtype + if image.ndim == 4: + # Add a singleton frame dimension after the channels dimension + image = image.unsqueeze(2) + + _, _, _, height, width = image.shape + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + # Encode image to latents using VAE + image = image.to(device=device, dtype=dtype) + if isinstance(generator, list): + # Like in prepare_latents, assume len(generator) == batch_size + ref_image_latents = [ + retrieve_latents(self.vae.encode(image), generator=g, sample_mode=sample_mode) for g in generator + ] + ref_image_latents = torch.cat(ref_image_latents) + else: + ref_image_latents = retrieve_latents(self.vae.encode(image), generator, sample_mode) + # Standardize latents in preparation for Wan VAE encode + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(ref_image_latents.device, ref_image_latents.dtype) + ) + latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + ref_image_latents.device, ref_image_latents.dtype + ) + ref_image_latents = (ref_image_latents - latents_mean) * latents_recip_std + # Handle the case where we supply one image and one generator, but batch_size > 1 (e.g. generating multiple + # videos per prompt) + if ref_image_latents.shape[0] == 1 and batch_size > 1: + ref_image_latents = ref_image_latents.expand(batch_size, -1, -1, -1, -1) + + # Prepare I2V mask in latent space and prepend to the reference image latents along channel dim + reference_image_mask = self.get_i2v_mask(batch_size, 1, latent_height, latent_width, 1, None, dtype, device) + reference_image_latents = torch.cat([reference_image_mask, ref_image_latents], dim=1) + + return reference_image_latents + + def prepare_prev_segment_cond_latents( + self, + prev_segment_cond_video: torch.Tensor | None = None, + background_video: torch.Tensor | None = None, + mask_video: torch.Tensor | None = None, + batch_size: int = 1, + segment_frame_length: int = 77, + start_frame: int = 0, + height: int = 720, + width: int = 1280, + prev_segment_cond_frames: int = 1, + task: str = "animate", + interpolation_mode: str = "bicubic", + sample_mode: str = "argmax", + generator: torch.Generator | list[torch.Generator] | None = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + ) -> torch.Tensor: + # prev_segment_cond_video shape: (B, C, T, H, W) in pixel space if supplied + # background_video shape: (B, C, T, H, W) (same as prev_segment_cond_video shape) + # mask_video shape: (B, 1, T, H, W) (same as prev_segment_cond_video, but with only 1 channel) + dtype = dtype or self.vae.dtype + if prev_segment_cond_video is None: + if task == "replace": + prev_segment_cond_video = background_video[:, :, :prev_segment_cond_frames].to(dtype) + else: + cond_frames_shape = (batch_size, 3, prev_segment_cond_frames, height, width) # In pixel space + prev_segment_cond_video = torch.zeros(cond_frames_shape, dtype=dtype, device=device) + + data_batch_size, channels, _, segment_height, segment_width = prev_segment_cond_video.shape + num_latent_frames = (segment_frame_length - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + if segment_height != height or segment_width != width: + print( + f"Interpolating prev segment cond video from ({segment_width}, {segment_height}) to ({width}, {height})" + ) + # Perform a 4D (spatial) rather than a 5D (spatiotemporal) reshape, following the original code + prev_segment_cond_video = prev_segment_cond_video.transpose(1, 2).flatten(0, 1) # [B * T, C, H, W] + prev_segment_cond_video = F.interpolate( + prev_segment_cond_video, size=(height, width), mode=interpolation_mode + ) + prev_segment_cond_video = prev_segment_cond_video.unflatten(0, (batch_size, -1)).transpose(1, 2) + + # Fill the remaining part of the cond video segment with zeros (if animating) or the background video (if + # replacing). + if task == "replace": + remaining_segment = background_video[:, :, prev_segment_cond_frames:].to(dtype) + else: + remaining_segment_frames = segment_frame_length - prev_segment_cond_frames + remaining_segment = torch.zeros( + batch_size, channels, remaining_segment_frames, height, width, dtype=dtype, device=device + ) + + # Prepend the conditioning frames from the previous segment to the remaining segment video in the frame dim + prev_segment_cond_video = prev_segment_cond_video.to(dtype=dtype) + full_segment_cond_video = torch.cat([prev_segment_cond_video, remaining_segment], dim=2) + + if isinstance(generator, list): + if data_batch_size == len(generator): + prev_segment_cond_latents = [ + retrieve_latents(self.vae.encode(full_segment_cond_video[i].unsqueeze(0)), g, sample_mode) + for i, g in enumerate(generator) + ] + elif data_batch_size == 1: + # Like prepare_latents, assume len(generator) == batch_size + prev_segment_cond_latents = [ + retrieve_latents(self.vae.encode(full_segment_cond_video), g, sample_mode) for g in generator + ] + else: + raise ValueError( + f"The batch size of the prev segment video should be either {len(generator)} or 1 but is" + f" {data_batch_size}" + ) + prev_segment_cond_latents = torch.cat(prev_segment_cond_latents) + else: + prev_segment_cond_latents = retrieve_latents( + self.vae.encode(full_segment_cond_video), generator, sample_mode + ) + # Standardize latents in preparation for Wan VAE encode + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(prev_segment_cond_latents.device, prev_segment_cond_latents.dtype) + ) + latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + prev_segment_cond_latents.device, prev_segment_cond_latents.dtype + ) + prev_segment_cond_latents = (prev_segment_cond_latents - latents_mean) * latents_recip_std + + # Prepare I2V mask + if task == "replace": + mask_video = 1 - mask_video + mask_video = mask_video.permute(0, 2, 1, 3, 4) + mask_video = mask_video.flatten(0, 1) + mask_video = F.interpolate(mask_video, size=(latent_height, latent_width), mode="nearest") + mask_pixel_values = mask_video.unflatten(0, (batch_size, -1)) + mask_pixel_values = mask_pixel_values.permute(0, 2, 1, 3, 4) # output shape: [B, C = 1, T, H_lat, W_lat] + else: + mask_pixel_values = None + prev_segment_cond_mask = self.get_i2v_mask( + batch_size, + num_latent_frames, + latent_height, + latent_width, + mask_len=prev_segment_cond_frames if start_frame > 0 else 0, + mask_pixel_values=mask_pixel_values, + dtype=dtype, + device=device, + ) + + # Prepend cond I2V mask to prev segment cond latents along channel dimension + prev_segment_cond_latents = torch.cat([prev_segment_cond_mask, prev_segment_cond_latents], dim=1) + return prev_segment_cond_latents + + def prepare_pose_latents( + self, + pose_video: torch.Tensor, + batch_size: int = 1, + sample_mode: int = "argmax", + generator: torch.Generator | list[torch.Generator] | None = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + ) -> torch.Tensor: + # pose_video shape: (B, C, T, H, W) + pose_video = pose_video.to(device=device, dtype=dtype if dtype is not None else self.vae.dtype) + if isinstance(generator, list): + pose_latents = [ + retrieve_latents(self.vae.encode(pose_video), generator=g, sample_mode=sample_mode) for g in generator + ] + pose_latents = torch.cat(pose_latents) + else: + pose_latents = retrieve_latents(self.vae.encode(pose_video), generator, sample_mode) + # Standardize latents in preparation for Wan VAE encode + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(pose_latents.device, pose_latents.dtype) + ) + latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + pose_latents.device, pose_latents.dtype + ) + pose_latents = (pose_latents - latents_mean) * latents_recip_std + if pose_latents.shape[0] == 1 and batch_size > 1: + pose_latents = pose_latents.expand(batch_size, -1, -1, -1, -1) + return pose_latents + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 720, + width: int = 1280, + num_frames: int = 77, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + shape = (batch_size, num_channels_latents, num_latent_frames + 1, latent_height, latent_width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents + + def pad_video_frames(self, frames: list[Any], num_target_frames: int) -> list[Any]: + """ + Pads an array-like video `frames` to `num_target_frames` using a "reflect"-like strategy. The frame dimension + is assumed to be the first dimension. In the 1D case, we can visualize this strategy as follows: + + pad_video_frames([1, 2, 3, 4, 5], 10) -> [1, 2, 3, 4, 5, 4, 3, 2, 1, 2] + """ + idx = 0 + flip = False + target_frames = [] + while len(target_frames) < num_target_frames: + target_frames.append(deepcopy(frames[idx])) + if flip: + idx -= 1 + else: + idx += 1 + if idx == 0 or idx == len(frames) - 1: + flip = not flip + + return target_frames + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + pose_video: list[PIL.Image.Image], + face_video: list[PIL.Image.Image], + background_video: list[PIL.Image.Image] | None = None, + mask_video: list[PIL.Image.Image] | None = None, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int = 720, + width: int = 1280, + segment_frame_length: int = 77, + num_inference_steps: int = 20, + mode: str = "animate", + prev_segment_conditioning_frames: int = 1, + motion_encode_batch_size: int | None = None, + guidance_scale: float = 1.0, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + image_embeds: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int, None], PipelineCallback | MultiPipelineCallbacks] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input character image to condition the generation on. Must be an image, a list of images or a + `torch.Tensor`. + pose_video (`list[PIL.Image.Image]`): + The input pose video to condition the generation on. Must be a list of PIL images. + face_video (`list[PIL.Image.Image]`): + The input face video to condition the generation on. Must be a list of PIL images. + background_video (`list[PIL.Image.Image]`, *optional*): + When mode is `"replace"`, the input background video to condition the generation on. Must be a list of + PIL images. + mask_video (`list[PIL.Image.Image]`, *optional*): + When mode is `"replace"`, the input mask video to condition the generation on. Must be a list of PIL + images. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + mode (`str`, defaults to `"animation"`): + The mode of the generation. Choose between `"animate"` and `"replace"`. + prev_segment_conditioning_frames (`int`, defaults to `1`): + The number of frames from the previous video segment to be used for temporal guidance. Recommended to + be 1 or 5. In general, should be 4N + 1, where N is a non-negative integer. + motion_encode_batch_size (`int`, *optional*): + The batch size for batched encoding of the face video via the motion encoder. This allows trading off + inference speed for lower memory usage by setting a smaller batch size. Will default to + `self.transformer.config.motion_encoder_batch_size` if not set. + height (`int`, defaults to `720`): + The height of the generated video. + width (`int`, defaults to `1280`): + The width of the generated video. + segment_frame_length (`int`, defaults to `77`): + The number of frames in each generated video segment. The total frames of video generated will be equal + to the number of frames in `pose_video`; we will generate the video in segments until we have hit this + length. In general, should be 4N + 1, where N is a non-negative integer. + num_inference_steps (`int`, defaults to `20`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `1.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. By default, CFG is not used in Wan + Animate inference. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `negative_prompt` input argument. + image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided, + image embeddings are generated from the `image` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + + Examples: + + Returns: + [`~WanPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + image, + pose_video, + face_video, + background_video, + mask_video, + height, + width, + prompt_embeds, + negative_prompt_embeds, + image_embeds, + callback_on_step_end_tensor_inputs, + mode, + prev_segment_conditioning_frames, + ) + + if segment_frame_length % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`segment_frame_length - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the" + f" nearest number." + ) + segment_frame_length = ( + segment_frame_length // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + ) + segment_frame_length = max(segment_frame_length, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # As we generate in segments of `segment_frame_length`, set the target frame length to be the least multiple + # of the effective segment length greater than or equal to the length of `pose_video`. + cond_video_frames = len(pose_video) + effective_segment_length = segment_frame_length - prev_segment_conditioning_frames + last_segment_frames = (cond_video_frames - prev_segment_conditioning_frames) % effective_segment_length + if last_segment_frames == 0: + num_padding_frames = 0 + else: + num_padding_frames = effective_segment_length - last_segment_frames + num_target_frames = cond_video_frames + num_padding_frames + num_segments = num_target_frames // effective_segment_length + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Preprocess and encode the reference (character) image + image_height, image_width = self.video_processor.get_default_height_width(image) + if image_height != height or image_width != width: + logger.warning(f"Reshaping reference image from ({image_width}, {image_height}) to ({width}, {height})") + image_pixels = self.vae_image_processor.preprocess(image, height=height, width=width, resize_mode="fill").to( + device, dtype=torch.float32 + ) + + # Get CLIP features from the reference image + if image_embeds is None: + image_embeds = self.encode_image(image, device) + image_embeds = image_embeds.repeat(batch_size * num_videos_per_prompt, 1, 1) + image_embeds = image_embeds.to(transformer_dtype) + + # 5. Encode conditioning videos (pose, face) + pose_video = self.pad_video_frames(pose_video, num_target_frames) + face_video = self.pad_video_frames(face_video, num_target_frames) + + # TODO: also support np.ndarray input (e.g. from decord like the original implementation?) + pose_video_width, pose_video_height = pose_video[0].size + if pose_video_height != height or pose_video_width != width: + logger.warning( + f"Reshaping pose video from ({pose_video_width}, {pose_video_height}) to ({width}, {height})" + ) + pose_video = self.video_processor.preprocess_video(pose_video, height=height, width=width).to( + device, dtype=torch.float32 + ) + + face_video_width, face_video_height = face_video[0].size + expected_face_size = self.transformer.config.motion_encoder_size + if face_video_width != expected_face_size or face_video_height != expected_face_size: + logger.warning( + f"Reshaping face video from ({face_video_width}, {face_video_height}) to ({expected_face_size}," + f" {expected_face_size})" + ) + face_video = self.video_processor.preprocess_video( + face_video, height=expected_face_size, width=expected_face_size + ).to(device, dtype=torch.float32) + + if mode == "replace": + background_video = self.pad_video_frames(background_video, num_target_frames) + mask_video = self.pad_video_frames(mask_video, num_target_frames) + + background_video = self.video_processor.preprocess_video(background_video, height=height, width=width).to( + device, dtype=torch.float32 + ) + mask_video = self.video_processor_for_mask.preprocess_video(mask_video, height=height, width=width).to( + device, dtype=torch.float32 + ) + + # 6. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 7. Prepare latent variables which stay constant for all inference segments + num_channels_latents = self.vae.config.z_dim + + # Get VAE-encoded latents of the reference (character) image + reference_image_latents = self.prepare_reference_image_latents( + image_pixels, batch_size * num_videos_per_prompt, generator=generator, device=device + ) + + # 8. Loop over video inference segments + start = 0 + end = segment_frame_length # Data space frames, not latent frames + all_out_frames = [] + out_frames = None + + for _ in range(num_segments): + assert start + prev_segment_conditioning_frames < cond_video_frames + + # Sample noisy latents from prior for the current inference segment + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames=segment_frame_length, + dtype=torch.float32, + device=device, + generator=generator, + latents=latents if start == 0 else None, # Only use pre-calculated latents for first segment + ) + + pose_video_segment = pose_video[:, :, start:end] + face_video_segment = face_video[:, :, start:end] + + face_video_segment = face_video_segment.expand(batch_size * num_videos_per_prompt, -1, -1, -1, -1) + face_video_segment = face_video_segment.to(dtype=transformer_dtype) + + if start > 0: + prev_segment_cond_video = out_frames[:, :, -prev_segment_conditioning_frames:].clone().detach() + else: + prev_segment_cond_video = None + + if mode == "replace": + background_video_segment = background_video[:, :, start:end] + mask_video_segment = mask_video[:, :, start:end] + + background_video_segment = background_video_segment.expand( + batch_size * num_videos_per_prompt, -1, -1, -1, -1 + ) + mask_video_segment = mask_video_segment.expand(batch_size * num_videos_per_prompt, -1, -1, -1, -1) + else: + background_video_segment = None + mask_video_segment = None + + pose_latents = self.prepare_pose_latents( + pose_video_segment, batch_size * num_videos_per_prompt, generator=generator, device=device + ) + pose_latents = pose_latents.to(dtype=transformer_dtype) + + prev_segment_cond_latents = self.prepare_prev_segment_cond_latents( + prev_segment_cond_video, + background_video=background_video_segment, + mask_video=mask_video_segment, + batch_size=batch_size * num_videos_per_prompt, + segment_frame_length=segment_frame_length, + start_frame=start, + height=height, + width=width, + prev_segment_cond_frames=prev_segment_conditioning_frames, + task=mode, + generator=generator, + device=device, + ) + + # Concatenate the reference latents in the frame dimension + reference_latents = torch.cat([reference_image_latents, prev_segment_cond_latents], dim=2) + + # 8.1 Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + # Concatenate the reference image + prev segment conditioning in the channel dim + latent_model_input = torch.cat([latents, reference_latents], dim=1).to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=image_embeds, + pose_hidden_states=pose_latents, + face_pixel_values=face_video_segment, + motion_encode_batch_size=motion_encode_batch_size, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + # Blank out face for unconditional guidance (set all pixels to -1) + face_pixel_values_uncond = face_video_segment * 0 - 1 + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_image=image_embeds, + pose_hidden_states=pose_latents, + face_pixel_values=face_pixel_values_uncond, + motion_encode_batch_size=motion_encode_batch_size, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + latents = latents.to(self.vae.dtype) + # Destandardize latents in preparation for Wan VAE decoding + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view( + 1, self.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_recip_std + latents_mean + # Skip the first latent frame (used for conditioning) + out_frames = self.vae.decode(latents[:, :, 1:], return_dict=False)[0] + + if start > 0: + out_frames = out_frames[:, :, prev_segment_conditioning_frames:] + all_out_frames.append(out_frames) + + start += effective_segment_length + end += effective_segment_length + + # Reset scheduler timesteps / state for next denoising loop + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + self._current_timestep = None + assert start + prev_segment_conditioning_frames >= cond_video_frames + + if not output_type == "latent": + video = torch.cat(all_out_frames, dim=2)[:, :, :cond_video_frames] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_wan_i2v.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_wan_i2v.py new file mode 100644 index 0000000000000000000000000000000000000000..f669e9b1d0ecde901f3c33346af6673d4cb146aa --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -0,0 +1,838 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +from typing import Any, Callable + +import PIL +import regex as re +import torch +from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import WanLoraLoaderMixin +from ...models import AutoencoderKLWan, WanTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import WanPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> import numpy as np + >>> from diffusers import AutoencoderKLWan, WanImageToVideoPipeline + >>> from diffusers.utils import export_to_video, load_image + >>> from transformers import CLIPVisionModel + + >>> # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers + >>> model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers" + >>> image_encoder = CLIPVisionModel.from_pretrained( + ... model_id, subfolder="image_encoder", torch_dtype=torch.float32 + ... ) + >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + >>> pipe = WanImageToVideoPipeline.from_pretrained( + ... model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" + ... ) + >>> max_area = 480 * 832 + >>> aspect_ratio = image.height / image.width + >>> mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] + >>> height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + >>> width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + >>> image = image.resize((width, height)) + >>> prompt = ( + ... "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in " + ... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." + ... ) + >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + + >>> output = pipe( + ... image=image, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=height, + ... width=width, + ... num_frames=81, + ... guidance_scale=5.0, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=16) + ``` +""" + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): + r""" + Pipeline for image-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + image_encoder ([`CLIPVisionModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically + the + [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large) + variant. + transformer ([`WanTransformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + transformer_2 ([`WanTransformer3DModel`], *optional*): + Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising, + `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only + `transformer` is used. + boundary_ratio (`float`, *optional*, defaults to `None`): + Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising. + The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided, + `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps < + boundary_timestep. If `None`, only `transformer` is used for the entire denoising process. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _optional_components = ["transformer", "transformer_2", "image_encoder", "image_processor"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKLWan, + scheduler: FlowMatchEulerDiscreteScheduler, + image_processor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModel = None, + transformer: WanTransformer3DModel = None, + transformer_2: WanTransformer3DModel = None, + boundary_ratio: float | None = None, + expand_timesteps: bool = False, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + image_encoder=image_encoder, + transformer=transformer, + scheduler=scheduler, + image_processor=image_processor, + transformer_2=transformer_2, + ) + self.register_to_config(boundary_ratio=boundary_ratio, expand_timesteps=expand_timesteps) + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.image_processor = image_processor + + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_image( + self, + image: PipelineImageInput, + device: torch.device | None = None, + ): + device = device or self._execution_device + image = self.image_processor(images=image, return_tensors="pt").to(device) + image_embeds = self.image_encoder(**image, output_hidden_states=True) + return image_embeds.hidden_states[-2] + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + image_embeds=None, + callback_on_step_end_tensor_inputs=None, + guidance_scale_2=None, + ): + if image is not None and image_embeds is not None: + raise ValueError( + f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to" + " only forward one of the two." + ) + if image is None and image_embeds is None: + raise ValueError( + "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined." + ) + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if self.config.boundary_ratio is None and guidance_scale_2 is not None: + raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.") + + if self.config.boundary_ratio is not None and image_embeds is not None: + raise ValueError("Cannot forward `image_embeds` when the pipeline's `boundary_ratio` is not configured.") + + def prepare_latents( + self, + image: PipelineImageInput, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + last_image: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + image = image.unsqueeze(2) # [batch_size, channels, 1, height, width] + + if self.config.expand_timesteps: + video_condition = image + + elif last_image is None: + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 + ) + else: + last_image = last_image.unsqueeze(2) + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image], + dim=2, + ) + video_condition = video_condition.to(device=device, dtype=self.vae.dtype) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + + if isinstance(generator, list): + latent_condition = [ + retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator + ] + latent_condition = torch.cat(latent_condition) + else: + latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") + latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + + latent_condition = latent_condition.to(dtype) + latent_condition = (latent_condition - latents_mean) * latents_std + + if self.config.expand_timesteps: + first_frame_mask = torch.ones( + 1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device + ) + first_frame_mask[:, :, 0] = 0 + return latents, latent_condition, first_frame_mask + + mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) + + if last_image is None: + mask_lat_size[:, :, list(range(1, num_frames))] = 0 + else: + mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0 + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(latent_condition.device) + + return latents, torch.concat([mask_lat_size, latent_condition], dim=1) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + guidance_scale_2: float | None = None, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + image_embeds: torch.Tensor | None = None, + last_image: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, defaults to `480`): + The height of the generated video. + width (`int`, defaults to `832`): + The width of the generated video. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_scale_2 (`float`, *optional*, defaults to `None`): + Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's + `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2` + and the pipeline's `boundary_ratio` are not None. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `negative_prompt` input argument. + image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided, + image embeddings are generated from the `image` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + + Examples: + + Returns: + [`~WanPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds, + negative_prompt_embeds, + image_embeds, + callback_on_step_end_tensor_inputs, + guidance_scale_2, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + patch_size = ( + self.transformer.config.patch_size + if self.transformer is not None + else self.transformer_2.config.patch_size + ) + h_multiple_of = self.vae_scale_factor_spatial * patch_size[1] + w_multiple_of = self.vae_scale_factor_spatial * patch_size[2] + calc_height = height // h_multiple_of * h_multiple_of + calc_width = width // w_multiple_of * w_multiple_of + if height != calc_height or width != calc_width: + logger.warning( + f"`height` and `width` must be multiples of ({h_multiple_of}, {w_multiple_of}) for proper patchification. " + f"Adjusting ({height}, {width}) -> ({calc_height}, {calc_width})." + ) + height, width = calc_height, calc_width + + if self.config.boundary_ratio is not None and guidance_scale_2 is None: + guidance_scale_2 = guidance_scale + + self._guidance_scale = guidance_scale + self._guidance_scale_2 = guidance_scale_2 + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + # Encode image embedding + transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # only wan 2.1 i2v transformer accepts image_embeds + if self.transformer is not None and self.transformer.config.image_dim is not None: + if image_embeds is None: + if last_image is None: + image_embeds = self.encode_image(image, device) + else: + image_embeds = self.encode_image([image, last_image], device) + image_embeds = image_embeds.repeat(batch_size, 1, 1) + image_embeds = image_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.vae.config.z_dim + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + if last_image is not None: + last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( + device, dtype=torch.float32 + ) + + latents_outputs = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + last_image, + ) + if self.config.expand_timesteps: + # wan 2.2 5b i2v use firt_frame_mask to mask timesteps + latents, condition, first_frame_mask = latents_outputs + else: + latents, condition = latents_outputs + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + if self.config.boundary_ratio is not None: + boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps + else: + boundary_timestep = None + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + if boundary_timestep is None or t >= boundary_timestep: + # wan2.1 or high-noise stage in wan2.2 + current_model = self.transformer + current_guidance_scale = guidance_scale + else: + # low-noise stage in wan2.2 + current_model = self.transformer_2 + current_guidance_scale = guidance_scale_2 + + if self.config.expand_timesteps: + latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents + latent_model_input = latent_model_input.to(transformer_dtype) + + # seq_len: num_latent_frames * (latent_height // patch_size) * (latent_width // patch_size) + temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten() + # batch_size, seq_len + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + else: + latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + with current_model.cache_context("cond"): + noise_pred = current_model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + with current_model.cache_context("uncond"): + noise_uncond = current_model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if self.config.expand_timesteps: + latents = (1 - first_frame_mask) * condition + first_frame_mask * latents + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_wan_vace.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_wan_vace.py new file mode 100644 index 0000000000000000000000000000000000000000..c016eec1b535008ab3c33668beba1bac09f6f453 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_wan_vace.py @@ -0,0 +1,1043 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +from typing import Any, Callable + +import PIL.Image +import regex as re +import torch +from transformers import AutoTokenizer, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import WanLoraLoaderMixin +from ...models import AutoencoderKLWan, WanVACETransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import WanPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> import PIL.Image + >>> from diffusers import AutoencoderKLWan, WanVACEPipeline + >>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler + >>> from diffusers.utils import export_to_video, load_image + def prepare_video_and_mask(first_img: PIL.Image.Image, last_img: PIL.Image.Image, height: int, width: int, num_frames: int): + first_img = first_img.resize((width, height)) + last_img = last_img.resize((width, height)) + frames = [] + frames.append(first_img) + # Ideally, this should be 127.5 to match original code, but they perform computation on numpy arrays + # whereas we are passing PIL images. If you choose to pass numpy arrays, you can set it to 127.5 to + # match the original code. + frames.extend([PIL.Image.new("RGB", (width, height), (128, 128, 128))] * (num_frames - 2)) + frames.append(last_img) + mask_black = PIL.Image.new("L", (width, height), 0) + mask_white = PIL.Image.new("L", (width, height), 255) + mask = [mask_black, *[mask_white] * (num_frames - 2), mask_black] + return frames, mask + + >>> # Available checkpoints: Wan-AI/Wan2.1-VACE-1.3B-diffusers, Wan-AI/Wan2.1-VACE-14B-diffusers + >>> model_id = "Wan-AI/Wan2.1-VACE-1.3B-diffusers" + >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + >>> pipe = WanVACEPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) + >>> flow_shift = 3.0 # 5.0 for 720P, 3.0 for 480P + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift) + >>> pipe.to("cuda") + + >>> prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective." + >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + >>> first_frame = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png" + ... ) + >>> last_frame = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png>>> " + ... ) + + >>> height = 512 + >>> width = 512 + >>> num_frames = 81 + >>> video, mask = prepare_video_and_mask(first_frame, last_frame, height, width, num_frames) + + >>> output = pipe( + ... video=video, + ... mask=mask, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=height, + ... width=width, + ... num_frames=num_frames, + ... num_inference_steps=30, + ... guidance_scale=5.0, + ... generator=torch.Generator().manual_seed(42), + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=16) + ``` +""" + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin): + r""" + Pipeline for controllable generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + transformer ([`WanVACETransformer3DModel`], *optional*): + Conditional Transformer to denoise the input latents during the high-noise stage. In two-stage denoising, + `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of + `transformer` or `transformer_2` must be provided. + transformer_2 ([`WanVACETransformer3DModel`], *optional*): + Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising, + `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of + `transformer` or `transformer_2` must be provided. + boundary_ratio (`float`, *optional*, defaults to `None`): + Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising. + The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided, + `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps < + boundary_timestep. If `None`, only the available transformer is used for the entire denoising process. + """ + + model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _optional_components = ["transformer", "transformer_2"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKLWan, + scheduler: FlowMatchEulerDiscreteScheduler, + transformer: WanVACETransformer3DModel = None, + transformer_2: WanVACETransformer3DModel = None, + boundary_ratio: float | None = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + transformer_2=transformer_2, + scheduler=scheduler, + ) + self.register_to_config(boundary_ratio=boundary_ratio) + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + video=None, + mask=None, + reference_images=None, + guidance_scale_2=None, + ): + if self.transformer is not None: + base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1] + elif self.transformer_2 is not None: + base = self.vae_scale_factor_spatial * self.transformer_2.config.patch_size[1] + else: + raise ValueError( + "`transformer` or `transformer_2` component must be set in order to run inference with this pipeline" + ) + + if height % base != 0 or width % base != 0: + raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if self.config.boundary_ratio is None and guidance_scale_2 is not None: + raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if video is not None: + if mask is not None: + if len(video) != len(mask): + raise ValueError( + f"Length of `video` {len(video)} and `mask` {len(mask)} do not match. Please make sure that" + " they have the same length." + ) + if reference_images is not None: + is_pil_image = isinstance(reference_images, PIL.Image.Image) + is_list_of_pil_images = isinstance(reference_images, list) and all( + isinstance(ref_img, PIL.Image.Image) for ref_img in reference_images + ) + is_list_of_list_of_pil_images = isinstance(reference_images, list) and all( + isinstance(ref_img, list) and all(isinstance(ref_img_, PIL.Image.Image) for ref_img_ in ref_img) + for ref_img in reference_images + ) + if not (is_pil_image or is_list_of_pil_images or is_list_of_list_of_pil_images): + raise ValueError( + "`reference_images` has to be of type `PIL.Image.Image` or `list` of `PIL.Image.Image`, or " + "`list` of `list` of `PIL.Image.Image`, but is {type(reference_images)}" + ) + if is_list_of_list_of_pil_images and len(reference_images) != 1: + raise ValueError( + "The pipeline only supports generating one video at a time at the moment. When passing a list " + "of list of reference images, where the outer list corresponds to the batch size and the inner " + "list corresponds to list of conditioning images per video, please make sure to only pass " + "one inner list of reference images (i.e., `[[, , ...]]`" + ) + elif mask is not None: + raise ValueError("`mask` can only be passed if `video` is passed as well.") + + def preprocess_conditions( + self, + video: list[PipelineImageInput] | None = None, + mask: list[PipelineImageInput] | None = None, + reference_images: PIL.Image.Image | list[PIL.Image.Image] | list[list[PIL.Image.Image]] | None = None, + batch_size: int = 1, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + ): + if video is not None: + base = self.vae_scale_factor_spatial * ( + self.transformer.config.patch_size[1] + if self.transformer is not None + else self.transformer_2.config.patch_size[1] + ) + video_height, video_width = self.video_processor.get_default_height_width(video[0]) + + if video_height * video_width > height * width: + scale = min(width / video_width, height / video_height) + video_height, video_width = int(video_height * scale), int(video_width * scale) + + if video_height % base != 0 or video_width % base != 0: + logger.warning( + f"Video height and width should be divisible by {base}, but got {video_height} and {video_width}. " + ) + video_height = (video_height // base) * base + video_width = (video_width // base) * base + + assert video_height * video_width <= height * width + + video = self.video_processor.preprocess_video(video, video_height, video_width) + image_size = (video_height, video_width) # Use the height/width of video (with possible rescaling) + else: + video = torch.zeros(batch_size, 3, num_frames, height, width, dtype=dtype, device=device) + image_size = (height, width) # Use the height/width provider by user + + if mask is not None: + mask = self.video_processor.preprocess_video(mask, image_size[0], image_size[1]) + mask = torch.clamp((mask + 1) / 2, min=0, max=1) + else: + mask = torch.ones_like(video) + + video = video.to(dtype=dtype, device=device) + mask = mask.to(dtype=dtype, device=device) + + # Make a list of list of images where the outer list corresponds to video batch size and the inner list + # corresponds to list of conditioning images per video + if reference_images is None or isinstance(reference_images, PIL.Image.Image): + reference_images = [[reference_images] for _ in range(video.shape[0])] + elif isinstance(reference_images, (list, tuple)) and isinstance(next(iter(reference_images)), PIL.Image.Image): + reference_images = [reference_images] + elif ( + isinstance(reference_images, (list, tuple)) + and isinstance(next(iter(reference_images)), list) + and isinstance(next(iter(reference_images[0])), PIL.Image.Image) + ): + reference_images = reference_images + else: + raise ValueError( + "`reference_images` has to be of type `PIL.Image.Image` or `list` of `PIL.Image.Image`, or " + "`list` of `list` of `PIL.Image.Image`, but is {type(reference_images)}" + ) + + if video.shape[0] != len(reference_images): + raise ValueError( + f"Batch size of `video` {video.shape[0]} and length of `reference_images` {len(reference_images)} does not match." + ) + + ref_images_lengths = [len(reference_images_batch) for reference_images_batch in reference_images] + if any(l != ref_images_lengths[0] for l in ref_images_lengths): + raise ValueError( + f"All batches of `reference_images` should have the same length, but got {ref_images_lengths}. Support for this " + "may be added in the future." + ) + + reference_images_preprocessed = [] + for i, reference_images_batch in enumerate(reference_images): + preprocessed_images = [] + for j, image in enumerate(reference_images_batch): + if image is None: + continue + image = self.video_processor.preprocess(image, None, None) + img_height, img_width = image.shape[-2:] + scale = min(image_size[0] / img_height, image_size[1] / img_width) + new_height, new_width = int(img_height * scale), int(img_width * scale) + resized_image = torch.nn.functional.interpolate( + image, size=(new_height, new_width), mode="bilinear", align_corners=False + ).squeeze(0) # [C, H, W] + top = (image_size[0] - new_height) // 2 + left = (image_size[1] - new_width) // 2 + canvas = torch.ones(3, *image_size, device=device, dtype=dtype) + canvas[:, top : top + new_height, left : left + new_width] = resized_image + preprocessed_images.append(canvas) + reference_images_preprocessed.append(preprocessed_images) + + return video, mask, reference_images_preprocessed + + def prepare_video_latents( + self, + video: torch.Tensor, + mask: torch.Tensor, + reference_images: list[list[torch.Tensor]] | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + device: torch.device | None = None, + ) -> torch.Tensor: + device = device or self._execution_device + + if isinstance(generator, list): + # TODO: support this + raise ValueError("Passing a list of generators is not yet supported. This may be supported in the future.") + + if reference_images is None: + # For each batch of video, we set no re + # ference image (as one or more can be passed by user) + reference_images = [[None] for _ in range(video.shape[0])] + else: + if video.shape[0] != len(reference_images): + raise ValueError( + f"Batch size of `video` {video.shape[0]} and length of `reference_images` {len(reference_images)} does not match." + ) + + if video.shape[0] != 1: + # TODO: support this + raise ValueError( + "Generating with more than one video is not yet supported. This may be supported in the future." + ) + + vae_dtype = self.vae.dtype + video = video.to(dtype=vae_dtype) + + latents_mean = torch.tensor(self.vae.config.latents_mean, device=device, dtype=torch.float32).view( + 1, self.vae.config.z_dim, 1, 1, 1 + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=device, dtype=torch.float32).view( + 1, self.vae.config.z_dim, 1, 1, 1 + ) + + if mask is None: + latents = retrieve_latents(self.vae.encode(video), generator, sample_mode="argmax").unbind(0) + latents = ((latents.float() - latents_mean) * latents_std).to(vae_dtype) + else: + mask = torch.where(mask > 0.5, 1.0, 0.0).to(dtype=vae_dtype) + inactive = video * (1 - mask) + reactive = video * mask + inactive = retrieve_latents(self.vae.encode(inactive), generator, sample_mode="argmax") + reactive = retrieve_latents(self.vae.encode(reactive), generator, sample_mode="argmax") + inactive = ((inactive.float() - latents_mean) * latents_std).to(vae_dtype) + reactive = ((reactive.float() - latents_mean) * latents_std).to(vae_dtype) + latents = torch.cat([inactive, reactive], dim=1) + + latent_list = [] + for latent, reference_images_batch in zip(latents, reference_images): + for reference_image in reference_images_batch: + assert reference_image.ndim == 3 + reference_image = reference_image.to(dtype=vae_dtype) + reference_image = reference_image[None, :, None, :, :] # [1, C, 1, H, W] + reference_latent = retrieve_latents(self.vae.encode(reference_image), generator, sample_mode="argmax") + reference_latent = ((reference_latent.float() - latents_mean) * latents_std).to(vae_dtype) + reference_latent = reference_latent.squeeze(0) # [C, 1, H, W] + reference_latent = torch.cat([reference_latent, torch.zeros_like(reference_latent)], dim=0) + latent = torch.cat([reference_latent.squeeze(0), latent], dim=1) + latent_list.append(latent) + return torch.stack(latent_list) + + def prepare_masks( + self, + mask: torch.Tensor, + reference_images: list[torch.Tensor] | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + ) -> torch.Tensor: + if isinstance(generator, list): + # TODO: support this + raise ValueError("Passing a list of generators is not yet supported. This may be supported in the future.") + + if reference_images is None: + # For each batch of video, we set no reference image (as one or more can be passed by user) + reference_images = [[None] for _ in range(mask.shape[0])] + else: + if mask.shape[0] != len(reference_images): + raise ValueError( + f"Batch size of `mask` {mask.shape[0]} and length of `reference_images` {len(reference_images)} does not match." + ) + + if mask.shape[0] != 1: + # TODO: support this + raise ValueError( + "Generating with more than one video is not yet supported. This may be supported in the future." + ) + + transformer_patch_size = ( + self.transformer.config.patch_size[1] + if self.transformer is not None + else self.transformer_2.config.patch_size[1] + ) + + mask_list = [] + for mask_, reference_images_batch in zip(mask, reference_images): + num_channels, num_frames, height, width = mask_.shape + new_num_frames = (num_frames + self.vae_scale_factor_temporal - 1) // self.vae_scale_factor_temporal + new_height = height // (self.vae_scale_factor_spatial * transformer_patch_size) * transformer_patch_size + new_width = width // (self.vae_scale_factor_spatial * transformer_patch_size) * transformer_patch_size + mask_ = mask_[0, :, :, :] + mask_ = mask_.view( + num_frames, new_height, self.vae_scale_factor_spatial, new_width, self.vae_scale_factor_spatial + ) + mask_ = mask_.permute(2, 4, 0, 1, 3).flatten(0, 1) # [8x8, num_frames, new_height, new_width] + mask_ = torch.nn.functional.interpolate( + mask_.unsqueeze(0), size=(new_num_frames, new_height, new_width), mode="nearest-exact" + ).squeeze(0) + num_ref_images = len(reference_images_batch) + if num_ref_images > 0: + mask_padding = torch.zeros_like(mask_[:, :num_ref_images, :, :]) + mask_ = torch.cat([mask_padding, mask_], dim=1) + mask_list.append(mask_) + return torch.stack(mask_list) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + video: list[PipelineImageInput] | None = None, + mask: list[PipelineImageInput] | None = None, + reference_images: list[PipelineImageInput] | None = None, + conditioning_scale: float | list[float] | torch.Tensor = 1.0, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + guidance_scale_2: float | None = None, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` + instead. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + video (`list[PIL.Image.Image]`, *optional*): + The input video or videos to be used as a starting point for the generation. The video should be a list + of PIL images, a numpy array, or a torch tensor. Currently, the pipeline only supports generating one + video at a time. + mask (`list[PIL.Image.Image]`, *optional*): + The input mask defines which video regions to condition on and which to generate. Black areas in the + mask indicate conditioning regions, while white areas indicate regions for generation. The mask should + be a list of PIL images, a numpy array, or a torch tensor. Currently supports generating a single video + at a time. + reference_images (`list[PIL.Image.Image]`, *optional*): + A list of one or more reference images as extra conditioning for the generation. For example, if you + are trying to inpaint a video to change the character, you can pass reference images of the new + character here. Refer to the Diffusers [examples](https://github.com/huggingface/diffusers/pull/11582) + and original [user + guide](https://github.com/ali-vilab/VACE/blob/0897c6d055d7d9ea9e191dce763006664d9780f8/UserGuide.md) + for a full list of supported tasks and use cases. + conditioning_scale (`float`, `list[float]`, `torch.Tensor`, defaults to `1.0`): + The conditioning scale to be applied when adding the control conditioning latent stream to the + denoising latent stream in each control layer of the model. If a float is provided, it will be applied + uniformly to all layers. If a list or tensor is provided, it should have the same length as the number + of control layers in the model (`len(transformer.config.vace_layers)`). + height (`int`, defaults to `480`): + The height in pixels of the generated image. + width (`int`, defaults to `832`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_scale_2 (`float`, *optional*, defaults to `None`): + Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's + `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2` + and the pipeline's `boundary_ratio` are not None. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + + Examples: + + Returns: + [`~WanPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # Simplification of implementation for now + if prompt is not None and not isinstance(prompt, str): + raise ValueError("Passing a list of prompts is not yet supported. This may be supported in the future.") + if num_videos_per_prompt != 1: + raise ValueError( + "Generating multiple videos per prompt is not yet supported. This may be supported in the future." + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + video, + mask, + reference_images, + guidance_scale_2, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + if self.config.boundary_ratio is not None and guidance_scale_2 is None: + guidance_scale_2 = guidance_scale + + self._guidance_scale = guidance_scale + self._guidance_scale_2 = guidance_scale_2 + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + vae_dtype = self.vae.dtype + transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype + + vace_layers = ( + self.transformer.config.vace_layers + if self.transformer is not None + else self.transformer_2.config.vace_layers + ) + if isinstance(conditioning_scale, (int, float)): + conditioning_scale = [conditioning_scale] * len(vace_layers) + if isinstance(conditioning_scale, list): + if len(conditioning_scale) != len(vace_layers): + raise ValueError( + f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(vace_layers)}." + ) + conditioning_scale = torch.tensor(conditioning_scale) + if isinstance(conditioning_scale, torch.Tensor): + if conditioning_scale.size(0) != len(vace_layers): + raise ValueError( + f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(vace_layers)}." + ) + conditioning_scale = conditioning_scale.to(device=device, dtype=transformer_dtype) + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + video, mask, reference_images = self.preprocess_conditions( + video, + mask, + reference_images, + batch_size, + height, + width, + num_frames, + torch.float32, + device, + ) + num_reference_images = len(reference_images[0]) + + conditioning_latents = self.prepare_video_latents(video, mask, reference_images, generator, device) + mask = self.prepare_masks(mask, reference_images, generator) + conditioning_latents = torch.cat([conditioning_latents, mask], dim=1) + conditioning_latents = conditioning_latents.to(transformer_dtype) + + num_channels_latents = ( + self.transformer.config.in_channels + if self.transformer is not None + else self.transformer_2.config.in_channels + ) + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames + num_reference_images * self.vae_scale_factor_temporal, + torch.float32, + device, + generator, + latents, + ) + + if conditioning_latents.shape[2] != latents.shape[2]: + logger.warning( + "The number of frames in the conditioning latents does not match the number of frames to be generated. Generation quality may be affected." + ) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + if self.config.boundary_ratio is not None: + boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps + else: + boundary_timestep = None + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + if boundary_timestep is None or t >= boundary_timestep: + # wan2.1 or high-noise stage in wan2.2 + current_model = self.transformer + current_guidance_scale = guidance_scale + else: + # low-noise stage in wan2.2 + current_model = self.transformer_2 + current_guidance_scale = guidance_scale_2 + + latent_model_input = latents.to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + with current_model.cache_context("cond"): + noise_pred = current_model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + control_hidden_states=conditioning_latents, + control_hidden_states_scale=conditioning_scale, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + with current_model.cache_context("uncond"): + noise_uncond = current_model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + control_hidden_states=conditioning_latents, + control_hidden_states_scale=conditioning_scale, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents[:, :, num_reference_images:] + latents = latents.to(vae_dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_wan_video2video.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_wan_video2video.py new file mode 100644 index 0000000000000000000000000000000000000000..3d7c5297f4c480a13e94d7f9779ee055b26dcfae --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wan/pipeline_wan_video2video.py @@ -0,0 +1,729 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +from typing import Any, Callable + +import regex as re +import torch +from PIL import Image +from transformers import AutoTokenizer, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import WanLoraLoaderMixin +from ...models import AutoencoderKLWan, WanTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import WanPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers.utils import export_to_video, load_video + >>> from diffusers import AutoencoderKLWan, WanVideoToVideoPipeline + >>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler + + >>> # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers + >>> model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" + >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + >>> pipe = WanVideoToVideoPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) + >>> flow_shift = 3.0 # 5.0 for 720P, 3.0 for 480P + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift) + >>> pipe.to("cuda") + + >>> prompt = "A robot standing on a mountain top. The sun is setting in the background" + >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + >>> video = load_video( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4" + ... ) + >>> output = pipe( + ... video=video, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=480, + ... width=720, + ... guidance_scale=5.0, + ... strength=0.7, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=16) + ``` +""" + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class WanVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): + r""" + Pipeline for video-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + transformer ([`WanTransformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + transformer: WanTransformer3DModel, + vae: AutoencoderKLWan, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: str | list[str] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + max_sequence_length: int = 226, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + video=None, + latents=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if video is not None and latents is not None: + raise ValueError("Only one of `video` or `latents` should be provided") + + def prepare_latents( + self, + video: torch.Tensor | None = None, + batch_size: int = 1, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + timestep: torch.Tensor | None = None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + num_latent_frames = ( + (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1) + ) + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if latents is None: + init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, dtype + ) + + init_latents = (init_latents - latents_mean) * latents_std + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if hasattr(self.scheduler, "add_noise"): + latents = self.scheduler.add_noise(init_latents, noise, timestep) + else: + latents = self.scheduler.scale_noise(init_latents, timestep, noise) + else: + latents = latents.to(device) + + return latents + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, timesteps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + video: list[Image.Image] = None, + prompt: str | list[str] = None, + negative_prompt: str | list[str] = None, + height: int = 480, + width: int = 832, + num_inference_steps: int = 50, + timesteps: list[int] | None = None, + guidance_scale: float = 5.0, + strength: float = 0.8, + num_videos_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str | None = "np", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` + instead. + height (`int`, defaults to `480`): + The height in pixels of the generated image. + width (`int`, defaults to `832`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + strength (`float`, defaults to `0.8`): + Higher strength leads to more differences between original image and generated video. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length of the text encoder. If the prompt is longer than this, it will be + truncated. If the prompt is shorter, it will be padded to this length. + + Examples: + + Returns: + [`~WanPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + video, + latents, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + self._num_timesteps = len(timesteps) + + if latents is None: + video = self.video_processor.preprocess_video(video, height=height, width=width).to( + device, dtype=torch.float32 + ) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + video, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents, + latent_timestep, + ) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wuerstchen/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wuerstchen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ddb852d1931558fe0948e81e16cf9a92fc2a114b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wuerstchen/__init__.py @@ -0,0 +1,56 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modeling_paella_vq_model"] = ["PaellaVQModel"] + _import_structure["modeling_wuerstchen_diffnext"] = ["WuerstchenDiffNeXt"] + _import_structure["modeling_wuerstchen_prior"] = ["WuerstchenPrior"] + _import_structure["pipeline_wuerstchen"] = ["WuerstchenDecoderPipeline"] + _import_structure["pipeline_wuerstchen_combined"] = ["WuerstchenCombinedPipeline"] + _import_structure["pipeline_wuerstchen_prior"] = ["DEFAULT_STAGE_C_TIMESTEPS", "WuerstchenPriorPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modeling_paella_vq_model import PaellaVQModel + from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt + from .modeling_wuerstchen_prior import WuerstchenPrior + from .pipeline_wuerstchen import WuerstchenDecoderPipeline + from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline + from .pipeline_wuerstchen_prior import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPriorPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py new file mode 100644 index 0000000000000000000000000000000000000000..932c7ac618f6a3ed7c110a7fd7e205f035f789b9 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py @@ -0,0 +1,171 @@ +# Copyright (c) 2022 Dominic Rampas MIT License +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.autoencoders.vae import DecoderOutput, VectorQuantizer +from ...models.modeling_utils import ModelMixin +from ...models.vq_model import VQEncoderOutput +from ...utils.accelerate_utils import apply_forward_hook + + +class MixingResidualBlock(nn.Module): + """ + Residual block with mixing used by Paella's VQ-VAE. + """ + + def __init__(self, inp_channels, embed_dim): + super().__init__() + # depthwise + self.norm1 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6) + self.depthwise = nn.Sequential( + nn.ReplicationPad2d(1), nn.Conv2d(inp_channels, inp_channels, kernel_size=3, groups=inp_channels) + ) + + # channelwise + self.norm2 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(inp_channels, embed_dim), nn.GELU(), nn.Linear(embed_dim, inp_channels) + ) + + self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) + + def forward(self, x): + mods = self.gammas + x_temp = self.norm1(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * (1 + mods[0]) + mods[1] + x = x + self.depthwise(x_temp) * mods[2] + x_temp = self.norm2(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * (1 + mods[3]) + mods[4] + x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] + return x + + +class PaellaVQModel(ModelMixin, ConfigMixin): + r"""VQ-VAE model from Paella model. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + up_down_scale_factor (int, *optional*, defaults to 2): Up and Downscale factor of the input image. + levels (int, *optional*, defaults to 2): Number of levels in the model. + bottleneck_blocks (int, *optional*, defaults to 12): Number of bottleneck blocks in the model. + embed_dim (int, *optional*, defaults to 384): Number of hidden channels in the model. + latent_channels (int, *optional*, defaults to 4): Number of latent channels in the VQ-VAE model. + num_vq_embeddings (int, *optional*, defaults to 8192): Number of codebook vectors in the VQ-VAE. + scale_factor (float, *optional*, defaults to 0.3764): Scaling factor of the latent space. + """ + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_down_scale_factor: int = 2, + levels: int = 2, + bottleneck_blocks: int = 12, + embed_dim: int = 384, + latent_channels: int = 4, + num_vq_embeddings: int = 8192, + scale_factor: float = 0.3764, + ): + super().__init__() + + c_levels = [embed_dim // (2**i) for i in reversed(range(levels))] + # Encoder blocks + self.in_block = nn.Sequential( + nn.PixelUnshuffle(up_down_scale_factor), + nn.Conv2d(in_channels * up_down_scale_factor**2, c_levels[0], kernel_size=1), + ) + down_blocks = [] + for i in range(levels): + if i > 0: + down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) + block = MixingResidualBlock(c_levels[i], c_levels[i] * 4) + down_blocks.append(block) + down_blocks.append( + nn.Sequential( + nn.Conv2d(c_levels[-1], latent_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(latent_channels), # then normalize them to have mean 0 and std 1 + ) + ) + self.down_blocks = nn.Sequential(*down_blocks) + + # Vector Quantizer + self.vquantizer = VectorQuantizer(num_vq_embeddings, vq_embed_dim=latent_channels, legacy=False, beta=0.25) + + # Decoder blocks + up_blocks = [nn.Sequential(nn.Conv2d(latent_channels, c_levels[-1], kernel_size=1))] + for i in range(levels): + for j in range(bottleneck_blocks if i == 0 else 1): + block = MixingResidualBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4) + up_blocks.append(block) + if i < levels - 1: + up_blocks.append( + nn.ConvTranspose2d( + c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, padding=1 + ) + ) + self.up_blocks = nn.Sequential(*up_blocks) + self.out_block = nn.Sequential( + nn.Conv2d(c_levels[0], out_channels * up_down_scale_factor**2, kernel_size=1), + nn.PixelShuffle(up_down_scale_factor), + ) + + @apply_forward_hook + def encode(self, x: torch.Tensor, return_dict: bool = True) -> VQEncoderOutput: + h = self.in_block(x) + h = self.down_blocks(h) + + if not return_dict: + return (h,) + + return VQEncoderOutput(latents=h) + + @apply_forward_hook + def decode( + self, h: torch.Tensor, force_not_quantize: bool = True, return_dict: bool = True + ) -> DecoderOutput | torch.Tensor: + if not force_not_quantize: + quant, _, _ = self.vquantizer(h) + else: + quant = h + + x = self.up_blocks(quant) + dec = self.out_block(x) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward(self, sample: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + r""" + Args: + sample (`torch.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + h = self.encode(x).latents + dec = self.decode(h).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py new file mode 100644 index 0000000000000000000000000000000000000000..73e71b3076fbca259ae76138bc4ab3d3797e2755 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py @@ -0,0 +1,76 @@ +import torch +import torch.nn as nn + +from ...models.attention_processor import Attention + + +class WuerstchenLayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + x = x.permute(0, 2, 3, 1) + x = super().forward(x) + return x.permute(0, 3, 1, 2) + + +class TimestepBlock(nn.Module): + def __init__(self, c, c_timestep): + super().__init__() + + self.mapper = nn.Linear(c_timestep, c * 2) + + def forward(self, x, t): + a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1) + return x * (1 + a) + b + + +class ResBlock(nn.Module): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): + super().__init__() + + self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) + self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c) + ) + + def forward(self, x, x_skip=None): + x_res = x + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.norm(self.depthwise(x)).permute(0, 2, 3, 1) + x = self.channelwise(x).permute(0, 3, 1, 2) + return x + x_res + + +# from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105 +class GlobalResponseNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + agg_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + stand_div_norm = agg_norm / (agg_norm.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * stand_div_norm) + self.beta + x + + +class AttnBlock(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): + super().__init__() + + self.self_attn = self_attn + self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True) + self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c)) + + def forward(self, x, kv): + kv = self.kv_mapper(kv) + norm_x = self.norm(x) + if self.self_attn: + batch_size, channel, _, _ = x.shape + kv = torch.cat([norm_x.view(batch_size, channel, -1).transpose(1, 2), kv], dim=1) + x = x + self.attention(norm_x, encoder_hidden_states=kv) + return x diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py new file mode 100644 index 0000000000000000000000000000000000000000..77ae597655d164e25add03a641a6b8eca5396200 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py @@ -0,0 +1,254 @@ +# Copyright (c) 2023 Dominic Rampas MIT License +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import numpy as np +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin +from .modeling_wuerstchen_common import AttnBlock, GlobalResponseNorm, TimestepBlock, WuerstchenLayerNorm + + +class WuerstchenDiffNeXt(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + c_in=4, + c_out=4, + c_r=64, + patch_size=2, + c_cond=1024, + c_hidden=[320, 640, 1280, 1280], + nhead=[-1, 10, 20, 20], + blocks=[4, 4, 14, 4], + level_config=["CT", "CTA", "CTA", "CTA"], + inject_effnet=[False, True, True, True], + effnet_embd=16, + clip_embd=1024, + kernel_size=3, + dropout=0.1, + ): + super().__init__() + self.c_r = c_r + self.c_cond = c_cond + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + + # CONDITIONING + self.clip_mapper = nn.Linear(clip_embd, c_cond) + self.effnet_mappers = nn.ModuleList( + [ + nn.Conv2d(effnet_embd, c_cond, kernel_size=1) if inject else None + for inject in inject_effnet + list(reversed(inject_effnet)) + ] + ) + self.seq_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + nn.Conv2d(c_in * (patch_size**2), c_hidden[0], kernel_size=1), + WuerstchenLayerNorm(c_hidden[0], elementwise_affine=False, eps=1e-6), + ) + + def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0): + if block_type == "C": + return ResBlockStageB(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) + elif block_type == "A": + return AttnBlock(c_hidden, c_cond, nhead, self_attn=True, dropout=dropout) + elif block_type == "T": + return TimestepBlock(c_hidden, c_r) + else: + raise ValueError(f"Block type {block_type} not supported") + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + for i in range(len(c_hidden)): + down_block = nn.ModuleList() + if i > 0: + down_block.append( + nn.Sequential( + WuerstchenLayerNorm(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), + nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2), + ) + ) + for _ in range(blocks[i]): + for block_type in level_config[i]: + c_skip = c_cond if inject_effnet[i] else 0 + down_block.append(get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i])) + self.down_blocks.append(down_block) + + # -- up blocks + self.up_blocks = nn.ModuleList() + for i in reversed(range(len(c_hidden))): + up_block = nn.ModuleList() + for j in range(blocks[i]): + for k, block_type in enumerate(level_config[i]): + c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 + c_skip += c_cond if inject_effnet[i] else 0 + up_block.append(get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i])) + if i > 0: + up_block.append( + nn.Sequential( + WuerstchenLayerNorm(c_hidden[i], elementwise_affine=False, eps=1e-6), + nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2), + ) + ) + self.up_blocks.append(up_block) + + # OUTPUT + self.clf = nn.Sequential( + WuerstchenLayerNorm(c_hidden[0], elementwise_affine=False, eps=1e-6), + nn.Conv2d(c_hidden[0], 2 * c_out * (patch_size**2), kernel_size=1), + nn.PixelShuffle(patch_size), + ) + + # --- WEIGHT INIT --- + self.apply(self._init_weights) + + def _init_weights(self, m): + # General init + if isinstance(m, (nn.Conv2d, nn.Linear)): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + for mapper in self.effnet_mappers: + if mapper is not None: + nn.init.normal_(mapper.weight, std=0.02) # conditionings + nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings + nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + nn.init.constant_(self.clf[1].weight, 0) # outputs + + # blocks + for level_block in self.down_blocks + self.up_blocks: + for block in level_block: + if isinstance(block, ResBlockStageB): + block.channelwise[-1].weight.data *= np.sqrt(1 / sum(self.config.blocks)) + elif isinstance(block, TimestepBlock): + nn.init.constant_(block.mapper.weight, 0) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode="constant") + return emb.to(dtype=r.dtype) + + def gen_c_embeddings(self, clip): + clip = self.clip_mapper(clip) + clip = self.seq_norm(clip) + return clip + + def _down_encode(self, x, r_embed, effnet, clip=None): + level_outputs = [] + for i, down_block in enumerate(self.down_blocks): + effnet_c = None + for block in down_block: + if isinstance(block, ResBlockStageB): + if effnet_c is None and self.effnet_mappers[i] is not None: + dtype = effnet.dtype + effnet_c = self.effnet_mappers[i]( + nn.functional.interpolate( + effnet.float(), size=x.shape[-2:], mode="bicubic", antialias=True, align_corners=True + ).to(dtype) + ) + skip = effnet_c if self.effnet_mappers[i] is not None else None + x = block(x, skip) + elif isinstance(block, AttnBlock): + x = block(x, clip) + elif isinstance(block, TimestepBlock): + x = block(x, r_embed) + else: + x = block(x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, effnet, clip=None): + x = level_outputs[0] + for i, up_block in enumerate(self.up_blocks): + effnet_c = None + for j, block in enumerate(up_block): + if isinstance(block, ResBlockStageB): + if effnet_c is None and self.effnet_mappers[len(self.down_blocks) + i] is not None: + dtype = effnet.dtype + effnet_c = self.effnet_mappers[len(self.down_blocks) + i]( + nn.functional.interpolate( + effnet.float(), size=x.shape[-2:], mode="bicubic", antialias=True, align_corners=True + ).to(dtype) + ) + skip = level_outputs[i] if j == 0 and i > 0 else None + if effnet_c is not None: + if skip is not None: + skip = torch.cat([skip, effnet_c], dim=1) + else: + skip = effnet_c + x = block(x, skip) + elif isinstance(block, AttnBlock): + x = block(x, clip) + elif isinstance(block, TimestepBlock): + x = block(x, r_embed) + else: + x = block(x) + return x + + def forward(self, x, r, effnet, clip=None, x_cat=None, eps=1e-3, return_noise=True): + if x_cat is not None: + x = torch.cat([x, x_cat], dim=1) + # Process the conditioning embeddings + r_embed = self.gen_r_embedding(r) + if clip is not None: + clip = self.gen_c_embeddings(clip) + + # Model Blocks + x_in = x + x = self.embedding(x) + level_outputs = self._down_encode(x, r_embed, effnet, clip) + x = self._up_decode(level_outputs, r_embed, effnet, clip) + a, b = self.clf(x).chunk(2, dim=1) + b = b.sigmoid() * (1 - eps * 2) + eps + if return_noise: + return (x_in - a) / b + else: + return a, b + + +class ResBlockStageB(nn.Module): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): + super().__init__() + self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) + self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c + c_skip, c * 4), + nn.GELU(), + GlobalResponseNorm(c * 4), + nn.Dropout(dropout), + nn.Linear(c * 4, c), + ) + + def forward(self, x, x_skip=None): + x_res = x + x = self.norm(self.depthwise(x)) + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + x_res diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py new file mode 100644 index 0000000000000000000000000000000000000000..dbdd50871b43e6e4930451f2090a30fb95e14501 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py @@ -0,0 +1,113 @@ +# Copyright (c) 2023 Dominic Rampas MIT License +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin +from ...models.attention import AttentionMixin +from ...models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttnAddedKVProcessor, + AttnProcessor, +) +from ...models.modeling_utils import ModelMixin +from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm + + +class WuerstchenPrior(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): + unet_name = "prior" + _supports_gradient_checkpointing = True + + @register_to_config + def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1): + super().__init__() + + self.c_r = c_r + self.projection = nn.Conv2d(c_in, c, kernel_size=1) + self.cond_mapper = nn.Sequential( + nn.Linear(c_cond, c), + nn.LeakyReLU(0.2), + nn.Linear(c, c), + ) + + self.blocks = nn.ModuleList() + for _ in range(depth): + self.blocks.append(ResBlock(c, dropout=dropout)) + self.blocks.append(TimestepBlock(c, c_r)) + self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout)) + self.out = nn.Sequential( + WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6), + nn.Conv2d(c, c_in * 2, kernel_size=1), + ) + + self.gradient_checkpointing = False + self.set_default_attn_processor() + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode="constant") + return emb.to(dtype=r.dtype) + + def forward(self, x, r, c): + x_in = x + x = self.projection(x) + c_embed = self.cond_mapper(c) + r_embed = self.gen_r_embedding(r) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block in self.blocks: + if isinstance(block, AttnBlock): + x = self._gradient_checkpointing_func(block, x, c_embed) + elif isinstance(block, TimestepBlock): + x = self._gradient_checkpointing_func(block, x, r_embed) + else: + x = self._gradient_checkpointing_func(block, x) + else: + for block in self.blocks: + if isinstance(block, AttnBlock): + x = block(x, c_embed) + elif isinstance(block, TimestepBlock): + x = block(x, r_embed) + else: + x = block(x) + a, b = self.out(x).chunk(2, dim=1) + return (x_in - a) / ((1 - b).abs() + 1e-5) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py new file mode 100644 index 0000000000000000000000000000000000000000..cce05c1892013d5ec3dd69adc3b190abcefb7aad --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -0,0 +1,449 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...schedulers import DDPMWuerstchenScheduler +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput +from .modeling_paella_vq_model import PaellaVQModel +from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import WuerstchenPriorPipeline, WuerstchenDecoderPipeline + + >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained( + ... "warp-ai/wuerstchen-prior", torch_dtype=torch.float16 + ... ).to("cuda") + >>> gen_pipe = WuerstchenDecoderPipeline.from_pretrain("warp-ai/wuerstchen", torch_dtype=torch.float16).to( + ... "cuda" + ... ) + + >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" + >>> prior_output = pipe(prompt) + >>> images = gen_pipe(prior_output.image_embeddings, prompt=prompt) + ``` +""" + + +class WuerstchenDecoderPipeline(DeprecatedPipelineMixin, DiffusionPipeline): + """ + Pipeline for generating images from the Wuerstchen model. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + tokenizer (`CLIPTokenizer`): + The CLIP tokenizer. + text_encoder (`CLIPTextModel`): + The CLIP text encoder. + decoder ([`WuerstchenDiffNeXt`]): + The WuerstchenDiffNeXt unet decoder. + vqgan ([`PaellaVQModel`]): + The VQGAN model. + scheduler ([`DDPMWuerstchenScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + latent_dim_scale (float, `optional`, defaults to 10.67): + Multiplier to determine the VQ latent space size from the image embeddings. If the image embeddings are + height=24 and width=24, the VQ latent shape needs to be height=int(24*10.67)=256 and + width=int(24*10.67)=256 in order to match the training conditions. + """ + + model_cpu_offload_seq = "text_encoder->decoder->vqgan" + _callback_tensor_inputs = [ + "latents", + "text_encoder_hidden_states", + "negative_prompt_embeds", + "image_embeddings", + ] + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + decoder: WuerstchenDiffNeXt, + scheduler: DDPMWuerstchenScheduler, + vqgan: PaellaVQModel, + latent_dim_scale: float = 10.67, + ) -> None: + super().__init__() + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + decoder=decoder, + scheduler=scheduler, + vqgan=vqgan, + ) + self.register_to_config(latent_dim_scale=latent_dim_scale) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + attention_mask = attention_mask[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask.to(device)) + text_encoder_hidden_states = text_encoder_output.last_hidden_state + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_text_encoder_hidden_states = None + if do_classifier_free_guidance: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds_text_encoder_output = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device) + ) + + uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + return text_encoder_hidden_states, uncond_text_encoder_hidden_states + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image_embeddings: torch.Tensor | list[torch.Tensor], + prompt: str | list[str] = None, + num_inference_steps: int = 12, + timesteps: list[float] | None = None, + guidance_scale: float = 0.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + image_embedding (`torch.Tensor` or `list[torch.Tensor]`): + Image Embeddings either extracted from an image or generated by a Prior Model. + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + num_inference_steps (`int`, *optional*, defaults to 12): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 0.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `decoder_guidance_scale` is defined as `w` of + equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by + setting `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are + closely linked to the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `decoder_guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is a list with the generated image + embeddings. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + # 0. Define commonly used variables + device = self._execution_device + dtype = self.decoder.dtype + self._guidance_scale = guidance_scale + + # 1. Check inputs. Raise error if not correct + if not isinstance(prompt, list): + if isinstance(prompt, str): + prompt = [prompt] + else: + raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.") + + if self.do_classifier_free_guidance: + if negative_prompt is not None and not isinstance(negative_prompt, list): + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + else: + raise TypeError( + f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}." + ) + + if isinstance(image_embeddings, list): + image_embeddings = torch.cat(image_embeddings, dim=0) + if isinstance(image_embeddings, np.ndarray): + image_embeddings = torch.Tensor(image_embeddings, device=device).to(dtype=dtype) + if not isinstance(image_embeddings, torch.Tensor): + raise TypeError( + f"'image_embeddings' must be of type 'torch.Tensor' or 'np.array', but got {type(image_embeddings)}." + ) + + if not isinstance(num_inference_steps, int): + raise TypeError( + f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\ + In Case you want to provide explicit timesteps, please use the 'timesteps' argument." + ) + + # 2. Encode caption + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + image_embeddings.size(0) * num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + ) + text_encoder_hidden_states = ( + torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds + ) + effnet = ( + torch.cat([image_embeddings, torch.zeros_like(image_embeddings)]) + if self.do_classifier_free_guidance + else image_embeddings + ) + + # 3. Determine latent shape of latents + latent_height = int(image_embeddings.size(2) * self.config.latent_dim_scale) + latent_width = int(image_embeddings.size(3) * self.config.latent_dim_scale) + latent_features_shape = (image_embeddings.size(0) * num_images_per_prompt, 4, latent_height, latent_width) + + # 4. Prepare and set timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latents + latents = self.prepare_latents(latent_features_shape, dtype, device, generator, latents, self.scheduler) + + # 6. Run denoising loop + self._num_timesteps = len(timesteps[:-1]) + for i, t in enumerate(self.progress_bar(timesteps[:-1])): + ratio = t.expand(latents.size(0)).to(dtype) + # 7. Denoise latents + predicted_latents = self.decoder( + torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents, + r=torch.cat([ratio] * 2) if self.do_classifier_free_guidance else ratio, + effnet=effnet, + clip=text_encoder_hidden_states, + ) + + # 8. Check for classifier free guidance and apply it + if self.do_classifier_free_guidance: + predicted_latents_text, predicted_latents_uncond = predicted_latents.chunk(2) + predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale) + + # 9. Renoise latents to next timestep + latents = self.scheduler.step( + model_output=predicted_latents, + timestep=ratio, + sample=latents, + generator=generator, + ).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + image_embeddings = callback_outputs.pop("image_embeddings", image_embeddings) + text_encoder_hidden_states = callback_outputs.pop( + "text_encoder_hidden_states", text_encoder_hidden_states + ) + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type not in ["pt", "np", "pil", "latent"]: + raise ValueError( + f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}" + ) + + if not output_type == "latent": + # 10. Scale and decode the image latents with vq-vae + latents = self.vqgan.config.scale_factor * latents + images = self.vqgan.decode(latents).sample.clamp(0, 1) + if output_type == "np": + images = images.permute(0, 2, 3, 1).cpu().float().numpy() + elif output_type == "pil": + images = images.permute(0, 2, 3, 1).cpu().float().numpy() + images = self.numpy_to_pil(images) + else: + images = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return images + return ImagePipelineOutput(images) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py new file mode 100644 index 0000000000000000000000000000000000000000..16300a7c71d2446b381afc02951d7ac4c3d8ff32 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -0,0 +1,307 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable + +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...schedulers import DDPMWuerstchenScheduler +from ...utils import deprecate, replace_example_docstring +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline +from .modeling_paella_vq_model import PaellaVQModel +from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt +from .modeling_wuerstchen_prior import WuerstchenPrior +from .pipeline_wuerstchen import WuerstchenDecoderPipeline +from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline + + +TEXT2IMAGE_EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusions import WuerstchenCombinedPipeline + + >>> pipe = WuerstchenCombinedPipeline.from_pretrained("warp-ai/Wuerstchen", torch_dtype=torch.float16).to( + ... "cuda" + ... ) + >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" + >>> images = pipe(prompt=prompt) + ``` +""" + + +class WuerstchenCombinedPipeline(DeprecatedPipelineMixin, DiffusionPipeline): + """ + Combined Pipeline for text-to-image generation using Wuerstchen + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + tokenizer (`CLIPTokenizer`): + The decoder tokenizer to be used for text inputs. + text_encoder (`CLIPTextModel`): + The decoder text encoder to be used for text inputs. + decoder (`WuerstchenDiffNeXt`): + The decoder model to be used for decoder image generation pipeline. + scheduler (`DDPMWuerstchenScheduler`): + The scheduler to be used for decoder image generation pipeline. + vqgan (`PaellaVQModel`): + The VQGAN model to be used for decoder image generation pipeline. + prior_tokenizer (`CLIPTokenizer`): + The prior tokenizer to be used for text inputs. + prior_text_encoder (`CLIPTextModel`): + The prior text encoder to be used for text inputs. + prior_prior (`WuerstchenPrior`): + The prior model to be used for prior pipeline. + prior_scheduler (`DDPMWuerstchenScheduler`): + The scheduler to be used for prior pipeline. + """ + + _last_supported_version = "0.33.1" + _load_connected_pipes = True + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + decoder: WuerstchenDiffNeXt, + scheduler: DDPMWuerstchenScheduler, + vqgan: PaellaVQModel, + prior_tokenizer: CLIPTokenizer, + prior_text_encoder: CLIPTextModel, + prior_prior: WuerstchenPrior, + prior_scheduler: DDPMWuerstchenScheduler, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + decoder=decoder, + scheduler=scheduler, + vqgan=vqgan, + prior_prior=prior_prior, + prior_text_encoder=prior_text_encoder, + prior_tokenizer=prior_tokenizer, + prior_scheduler=prior_scheduler, + ) + self.prior_pipe = WuerstchenPriorPipeline( + prior=prior_prior, + text_encoder=prior_text_encoder, + tokenizer=prior_tokenizer, + scheduler=prior_scheduler, + ) + self.decoder_pipe = WuerstchenDecoderPipeline( + text_encoder=text_encoder, + tokenizer=tokenizer, + decoder=decoder, + scheduler=scheduler, + vqgan=vqgan, + ) + + def enable_xformers_memory_efficient_attention(self, attention_op: Callable | None = None): + self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) + + def enable_model_cpu_offload(self, gpu_id: int | None = None, device: torch.device | str = None): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) + self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) + + def enable_sequential_cpu_offload(self, gpu_id: int | None = None, device: torch.device | str = None): + r""" + Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗 + Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a + GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis. + Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower. + """ + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device) + + def progress_bar(self, iterable=None, total=None): + self.prior_pipe.progress_bar(iterable=iterable, total=total) + self.decoder_pipe.progress_bar(iterable=iterable, total=total) + + def set_progress_bar_config(self, **kwargs): + self.prior_pipe.set_progress_bar_config(**kwargs) + self.decoder_pipe.set_progress_bar_config(**kwargs) + + @torch.no_grad() + @replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + height: int = 512, + width: int = 512, + prior_num_inference_steps: int = 60, + prior_timesteps: list[float] | None = None, + prior_guidance_scale: float = 4.0, + num_inference_steps: int = 12, + decoder_timesteps: list[float] | None = None, + decoder_guidance_scale: float = 0.0, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + num_images_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + prior_callback_on_step_end: Callable[[int, int], None] | None = None, + prior_callback_on_step_end_tensor_inputs: list[str] = ["latents"], + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation for the prior and decoder. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* + prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` + input argument. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `prior_guidance_scale` is defined as `w` of + equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by + setting `prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are + closely linked to the text `prompt`, usually at the expense of lower image quality. + prior_num_inference_steps (`int | dict[float, int]`, *optional*, defaults to 60): + The number of prior denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. For more specific timestep spacing, you can pass customized + `prior_timesteps` + num_inference_steps (`int`, *optional*, defaults to 12): + The number of decoder denoising steps. More denoising steps usually lead to a higher quality image at + the expense of slower inference. For more specific timestep spacing, you can pass customized + `timesteps` + prior_timesteps (`list[float]`, *optional*): + Custom timesteps to use for the denoising process for the prior. If not defined, equal spaced + `prior_num_inference_steps` timesteps are used. Must be in descending order. + decoder_timesteps (`list[float]`, *optional*): + Custom timesteps to use for the denoising process for the decoder. If not defined, equal spaced + `num_inference_steps` timesteps are used. Must be in descending order. + decoder_guidance_scale (`float`, *optional*, defaults to 0.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + prior_callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `prior_callback_on_step_end(self: DiffusionPipeline, step: int, timestep: + int, callback_kwargs: Dict)`. + prior_callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the + list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in + the `._callback_tensor_inputs` attribute of your pipeline class. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + prior_kwargs = {} + if kwargs.get("prior_callback", None) is not None: + prior_kwargs["callback"] = kwargs.pop("prior_callback") + deprecate( + "prior_callback", + "1.0.0", + "Passing `prior_callback` as an input argument to `__call__` is deprecated, consider use `prior_callback_on_step_end`", + ) + if kwargs.get("prior_callback_steps", None) is not None: + deprecate( + "prior_callback_steps", + "1.0.0", + "Passing `prior_callback_steps` as an input argument to `__call__` is deprecated, consider use `prior_callback_on_step_end`", + ) + prior_kwargs["callback_steps"] = kwargs.pop("prior_callback_steps") + + prior_outputs = self.prior_pipe( + prompt=prompt if prompt_embeds is None else None, + height=height, + width=width, + num_inference_steps=prior_num_inference_steps, + timesteps=prior_timesteps, + guidance_scale=prior_guidance_scale, + negative_prompt=negative_prompt if negative_prompt_embeds is None else None, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + output_type="pt", + return_dict=False, + callback_on_step_end=prior_callback_on_step_end, + callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs, + **prior_kwargs, + ) + image_embeddings = prior_outputs[0] + + outputs = self.decoder_pipe( + image_embeddings=image_embeddings, + prompt=prompt if prompt is not None else "", + num_inference_steps=num_inference_steps, + timesteps=decoder_timesteps, + guidance_scale=decoder_guidance_scale, + negative_prompt=negative_prompt, + generator=generator, + output_type=output_type, + return_dict=return_dict, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + **kwargs, + ) + + return outputs diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py new file mode 100644 index 0000000000000000000000000000000000000000..e79fcf8378aa05dc8ee7acd1678ca5476284b9bd --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -0,0 +1,528 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from math import ceil +from typing import Callable + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...loaders import StableDiffusionLoraLoaderMixin +from ...schedulers import DDPMWuerstchenScheduler +from ...utils import BaseOutput, deprecate, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .modeling_wuerstchen_prior import WuerstchenPrior + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:] + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import WuerstchenPriorPipeline + + >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained( + ... "warp-ai/wuerstchen-prior", torch_dtype=torch.float16 + ... ).to("cuda") + + >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet" + >>> prior_output = pipe(prompt) + ``` +""" + + +@dataclass +class WuerstchenPriorPipelineOutput(BaseOutput): + """ + Output class for WuerstchenPriorPipeline. + + Args: + image_embeddings (`torch.Tensor` or `np.ndarray`) + Prior image embeddings for text prompt + + """ + + image_embeddings: torch.Tensor | np.ndarray + + +class WuerstchenPriorPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin): + """ + Pipeline for generating image prior for Wuerstchen. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + + Args: + prior ([`Prior`]): + The canonical unCLIP prior to approximate the image embedding from the text embedding. + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + scheduler ([`DDPMWuerstchenScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + latent_mean ('float', *optional*, defaults to 42.0): + Mean value for latent diffusers. + latent_std ('float', *optional*, defaults to 1.0): + Standard value for latent diffusers. + resolution_multiple ('float', *optional*, defaults to 42.67): + Default resolution for multiple images generated. + """ + + unet_name = "prior" + text_encoder_name = "text_encoder" + model_cpu_offload_seq = "text_encoder->prior" + _callback_tensor_inputs = ["latents", "text_encoder_hidden_states", "negative_prompt_embeds"] + _lora_loadable_modules = ["prior", "text_encoder"] + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + prior: WuerstchenPrior, + scheduler: DDPMWuerstchenScheduler, + latent_mean: float = 42.0, + latent_std: float = 1.0, + resolution_multiple: float = 42.67, + ) -> None: + super().__init__() + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + prior=prior, + scheduler=scheduler, + ) + self.register_to_config( + latent_mean=latent_mean, latent_std=latent_std, resolution_multiple=resolution_multiple + ) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def encode_prompt( + self, + device, + num_images_per_prompt, + do_classifier_free_guidance, + prompt=None, + negative_prompt=None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + ): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + attention_mask = attention_mask[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask.to(device) + ) + prompt_embeds = text_encoder_output.last_hidden_state + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + if negative_prompt_embeds is None and do_classifier_free_guidance: + uncond_tokens: list[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds_text_encoder_output = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device) + ) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.last_hidden_state + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + # done duplicates + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + num_inference_steps, + do_classifier_free_guidance, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if not isinstance(num_inference_steps, int): + raise TypeError( + f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\ + In Case you want to provide explicit timesteps, please use the 'timesteps' argument." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] | None = None, + height: int = 1024, + width: int = 1024, + num_inference_steps: int = 60, + timesteps: list[float] = None, + guidance_scale: float = 8.0, + negative_prompt: str | list[str] | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + output_type: str | None = "pt", + return_dict: bool = True, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + **kwargs, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 60): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 8.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `decoder_guidance_scale` is defined as `w` of + equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by + setting `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are + closely linked to the text `prompt`, usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `decoder_guidance_scale` is less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.WuerstchenPriorPipelineOutput`] or `tuple` [`~pipelines.WuerstchenPriorPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated image embeddings. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + # 0. Define commonly used variables + device = self._execution_device + self._guidance_scale = guidance_scale + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 1. Check inputs. Raise error if not correct + if prompt is not None and not isinstance(prompt, list): + if isinstance(prompt, str): + prompt = [prompt] + else: + raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.") + + if self.do_classifier_free_guidance: + if negative_prompt is not None and not isinstance(negative_prompt, list): + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + else: + raise TypeError( + f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}." + ) + + self.check_inputs( + prompt, + negative_prompt, + num_inference_steps, + self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 2. Encode caption + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_encoder_hidden_states = ( + torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds + ) + + # 3. Determine latent shape of image embeddings + dtype = text_encoder_hidden_states.dtype + latent_height = ceil(height / self.config.resolution_multiple) + latent_width = ceil(width / self.config.resolution_multiple) + num_channels = self.prior.config.c_in + effnet_features_shape = (num_images_per_prompt * batch_size, num_channels, latent_height, latent_width) + + # 4. Prepare and set timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latents + latents = self.prepare_latents(effnet_features_shape, dtype, device, generator, latents, self.scheduler) + + # 6. Run denoising loop + self._num_timesteps = len(timesteps[:-1]) + for i, t in enumerate(self.progress_bar(timesteps[:-1])): + ratio = t.expand(latents.size(0)).to(dtype) + + # 7. Denoise image embeddings + predicted_image_embedding = self.prior( + torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents, + r=torch.cat([ratio] * 2) if self.do_classifier_free_guidance else ratio, + c=text_encoder_hidden_states, + ) + + # 8. Check for classifier free guidance and apply it + if self.do_classifier_free_guidance: + predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2) + predicted_image_embedding = torch.lerp( + predicted_image_embedding_uncond, predicted_image_embedding_text, self.guidance_scale + ) + + # 9. Renoise latents to next timestep + latents = self.scheduler.step( + model_output=predicted_image_embedding, + timestep=ratio, + sample=latents, + generator=generator, + ).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + text_encoder_hidden_states = callback_outputs.pop( + "text_encoder_hidden_states", text_encoder_hidden_states + ) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + # 10. Denormalize the latents + latents = latents * self.config.latent_mean - self.config.latent_std + + # Offload all models + self.maybe_free_model_hooks() + + if output_type == "np": + latents = latents.cpu().float().numpy() + + if not return_dict: + return (latents,) + + return WuerstchenPriorPipelineOutput(latents) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__init__.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..14629a6e216058c1800de8a2dbfe28f4ec6d2d66 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__init__.py @@ -0,0 +1,59 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa: F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_output"] = ["ZImagePipelineOutput"] + _import_structure["pipeline_z_image"] = ["ZImagePipeline"] + _import_structure["pipeline_z_image_controlnet"] = ["ZImageControlNetPipeline"] + _import_structure["pipeline_z_image_controlnet_inpaint"] = ["ZImageControlNetInpaintPipeline"] + _import_structure["pipeline_z_image_img2img"] = ["ZImageImg2ImgPipeline"] + _import_structure["pipeline_z_image_inpaint"] = ["ZImageInpaintPipeline"] + _import_structure["pipeline_z_image_omni"] = ["ZImageOmniPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_output import ZImagePipelineOutput + from .pipeline_z_image import ZImagePipeline + from .pipeline_z_image_controlnet import ZImageControlNetPipeline + from .pipeline_z_image_controlnet_inpaint import ZImageControlNetInpaintPipeline + from .pipeline_z_image_img2img import ZImageImg2ImgPipeline + from .pipeline_z_image_inpaint import ZImageInpaintPipeline + from .pipeline_z_image_omni import ZImageOmniPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__pycache__/__init__.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2526e19eedbdc6bebfdc81ff710ccafa0be0c47a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__pycache__/__init__.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__pycache__/pipeline_output.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__pycache__/pipeline_output.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75bd00cbd299a56a121306cdf13e87064443fa05 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__pycache__/pipeline_output.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__pycache__/pipeline_z_image.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__pycache__/pipeline_z_image.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3690c077bb032ecc2a0fc87203f2ccd878ae5e4c Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__pycache__/pipeline_z_image.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__pycache__/pipeline_z_image_controlnet.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__pycache__/pipeline_z_image_controlnet.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2cc9a4548f59ccbcb1d8a5edee4e2bc150a6452 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__pycache__/pipeline_z_image_controlnet.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__pycache__/pipeline_z_image_controlnet_inpaint.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__pycache__/pipeline_z_image_controlnet_inpaint.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2b5595e787419ec410d842201c1de6e4fbc13ba Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__pycache__/pipeline_z_image_controlnet_inpaint.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__pycache__/pipeline_z_image_img2img.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__pycache__/pipeline_z_image_img2img.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2112702533bc78363298a522558701e3d1737a0 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__pycache__/pipeline_z_image_img2img.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__pycache__/pipeline_z_image_inpaint.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__pycache__/pipeline_z_image_inpaint.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fe717e2a39fb73eb398dc1e461742e44d0e7da8 Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__pycache__/pipeline_z_image_inpaint.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__pycache__/pipeline_z_image_omni.cpython-312.pyc b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__pycache__/pipeline_z_image_omni.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f76755b97bcf0604ca1606419f435218abe6294a Binary files /dev/null and b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/__pycache__/pipeline_z_image_omni.cpython-312.pyc differ diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/pipeline_output.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/pipeline_output.py new file mode 100644 index 0000000000000000000000000000000000000000..5cbcbee4bfab7a28f33dcbffd0863749320d763a --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/pipeline_output.py @@ -0,0 +1,34 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +import numpy as np +import PIL.Image + +from diffusers.utils import BaseOutput + + +@dataclass +class ZImagePipelineOutput(BaseOutput): + """ + Output class for Z-Image pipelines. + + Args: + images (`list[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: list[PIL.Image.Image, np.ndarray] diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/pipeline_z_image.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/pipeline_z_image.py new file mode 100644 index 0000000000000000000000000000000000000000..959368ec1cd102c3494b082252a39dfdfa4ce485 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/pipeline_z_image.py @@ -0,0 +1,594 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +from transformers import AutoTokenizer, PreTrainedModel + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import ZImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import ZImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ZImagePipeline + + >>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch. + >>> # (1) Use flash attention 2 + >>> # pipe.transformer.set_attention_backend("flash") + >>> # (2) Use flash attention 3 + >>> # pipe.transformer.set_attention_backend("_flash_3") + + >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。" + >>> image = pipe( + ... prompt, + ... height=1024, + ... width=1024, + ... num_inference_steps=9, + ... guidance_scale=0.0, + ... generator=torch.Generator("cuda").manual_seed(42), + ... ).images[0] + >>> image.save("zimage.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImagePipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin): + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + prompt_embeds: list[torch.FloatTensor] | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 512, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + def _encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + prompt_embeds: list[torch.FloatTensor] | None = None, + max_sequence_length: int = 512, + ) -> list[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + + for i in range(len(prompt_embeds)): + embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embeddings_list + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 0 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 5.0, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: list[torch.FloatTensor] | None = None, + negative_prompt_embeds: list[torch.FloatTensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + cfg_normalization (`bool`, *optional*, defaults to False): + Whether to apply configuration normalization. + cfg_truncation (`float`, *optional*, defaults to 1.0): + The truncation value for configuration. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`list[torch.FloatTensor]`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`list[torch.FloatTensor]`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain + tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + height = height or 1024 + width = width or 1024 + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + device = self._execution_device + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.in_channels + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + actual_batch_size = batch_size * num_images_per_prompt + image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2) + + # 5. Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + timestep_model_input = timestep.repeat(2) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + timestep_model_input = timestep + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + model_out_list = self.transformer( + latent_model_input_list, timestep_model_input, prompt_embeds_model_input, return_dict=False + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..1e49737bb5b01cbc4d83e4c0e107a1e61b103606 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -0,0 +1,726 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +from transformers import AutoTokenizer, PreTrainedModel + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.controlnets import ZImageControlNetModel +from ...models.transformers import ZImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import ZImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ZImageControlNetPipeline + >>> from diffusers import ZImageControlNetModel + >>> from diffusers.utils import load_image + >>> from huggingface_hub import hf_hub_download + + >>> controlnet = ZImageControlNetModel.from_single_file( + ... hf_hub_download( + ... "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union", + ... filename="Z-Image-Turbo-Fun-Controlnet-Union.safetensors", + ... ), + ... torch_dtype=torch.bfloat16, + ... ) + + >>> # 2.1 + >>> # controlnet = ZImageControlNetModel.from_single_file( + >>> # hf_hub_download( + >>> # "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0", + >>> # filename="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors", + >>> # ), + >>> # torch_dtype=torch.bfloat16, + >>> # ) + + >>> # 2.0 + >>> # controlnet = ZImageControlNetModel.from_single_file( + >>> # hf_hub_download( + >>> # "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0", + >>> # filename="Z-Image-Turbo-Fun-Controlnet-Union-2.0.safetensors", + >>> # ), + >>> # torch_dtype=torch.bfloat16, + >>> # ) + + >>> pipe = ZImageControlNetPipeline.from_pretrained( + ... "Tongyi-MAI/Z-Image-Turbo", controlnet=controlnet, torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch. + >>> # (1) Use flash attention 2 + >>> # pipe.transformer.set_attention_backend("flash") + >>> # (2) Use flash attention 3 + >>> # pipe.transformer.set_attention_backend("_flash_3") + + >>> control_image = load_image( + ... "https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union/resolve/main/asset/pose.jpg?download=true" + ... ) + >>> prompt = "一位年轻女子站在阳光明媚的海岸线上,白裙在轻拂的海风中微微飘动。她拥有一头鲜艳的紫色长发,在风中轻盈舞动,发间系着一个精致的黑色蝴蝶结,与身后柔和的蔚蓝天空形成鲜明对比。她面容清秀,眉目精致,透着一股甜美的青春气息;神情柔和,略带羞涩,目光静静地凝望着远方的地平线,双手自然交叠于身前,仿佛沉浸在思绪之中。在她身后,是辽阔无垠、波光粼粼的大海,阳光洒在海面上,映出温暖的金色光晕。" + >>> image = pipe( + ... prompt, + ... control_image=control_image, + ... controlnet_conditioning_scale=0.75, + ... height=1728, + ... width=992, + ... num_inference_steps=9, + ... guidance_scale=0.0, + ... generator=torch.Generator("cuda").manual_seed(43), + ... ).images[0] + >>> image.save("zimage.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImageControlNetPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin): + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageTransformer2DModel, + controlnet: ZImageControlNetModel, + ): + super().__init__() + controlnet = ZImageControlNetModel.from_transformer(controlnet, transformer) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + controlnet=controlnet, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + prompt_embeds: list[torch.FloatTensor] | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 512, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + def _encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + prompt_embeds: list[torch.FloatTensor] | None = None, + max_sequence_length: int = 512, + ) -> list[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + + for i in range(len(prompt_embeds)): + embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embeddings_list + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 0 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 5.0, + control_image: PipelineImageInput = None, + controlnet_conditioning_scale: float | list[float] = 0.75, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: list[torch.FloatTensor] | None = None, + negative_prompt_embeds: list[torch.FloatTensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + cfg_normalization (`bool`, *optional*, defaults to False): + Whether to apply configuration normalization. + cfg_truncation (`float`, *optional*, defaults to 1.0): + The truncation value for configuration. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`list[torch.FloatTensor]`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`list[torch.FloatTensor]`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain + tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + height = height or 1024 + width = width or 1024 + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + device = self._execution_device + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.in_channels + + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image.shape[-2:] + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator, sample_mode="argmax") + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + control_image = control_image.unsqueeze(2) + + if num_channels_latents != self.controlnet.config.control_in_dim: + # For model version 2.0 + control_image = torch.cat( + [ + control_image, + torch.zeros( + control_image.shape[0], + self.controlnet.config.control_in_dim - num_channels_latents, + *control_image.shape[2:], + ).to(device=control_image.device, dtype=control_image.dtype), + ], + dim=1, + ) + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + actual_batch_size = batch_size * num_images_per_prompt + image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2) + + # 5. Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + timestep_model_input = timestep.repeat(2) + control_image_input = control_image.repeat(2, 1, 1, 1, 1) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + timestep_model_input = timestep + control_image_input = control_image + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + controlnet_block_samples = self.controlnet( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + control_image_input, + conditioning_scale=controlnet_conditioning_scale, + ) + + model_out_list = self.transformer( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + controlnet_block_samples=controlnet_block_samples, + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..09f9b23954585e2ee68578affc874721f7b6942b --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py @@ -0,0 +1,748 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +import torch.nn.functional as F +from transformers import AutoTokenizer, PreTrainedModel + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.controlnets import ZImageControlNetModel +from ...models.transformers import ZImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import ZImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ZImageControlNetInpaintPipeline + >>> from diffusers import ZImageControlNetModel + >>> from diffusers.utils import load_image + >>> from huggingface_hub import hf_hub_download + + >>> controlnet = ZImageControlNetModel.from_single_file( + ... hf_hub_download( + ... "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0", + ... filename="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors", + ... ), + ... torch_dtype=torch.bfloat16, + ... ) + + >>> # 2.0 + >>> # controlnet = ZImageControlNetModel.from_single_file( + >>> # hf_hub_download( + >>> # "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0", + >>> # filename="Z-Image-Turbo-Fun-Controlnet-Union-2.0.safetensors", + >>> # ), + >>> # torch_dtype=torch.bfloat16, + >>> # ) + + >>> pipe = ZImageControlNetInpaintPipeline.from_pretrained( + ... "Tongyi-MAI/Z-Image-Turbo", controlnet=controlnet, torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch. + >>> # (1) Use flash attention 2 + >>> # pipe.transformer.set_attention_backend("flash") + >>> # (2) Use flash attention 3 + >>> # pipe.transformer.set_attention_backend("_flash_3") + + >>> image = load_image( + ... "https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0/resolve/main/asset/inpaint.jpg?download=true" + ... ) + >>> mask_image = load_image( + ... "https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0/resolve/main/asset/mask.jpg?download=true" + ... ) + >>> control_image = load_image( + ... "https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0/resolve/main/asset/pose.jpg?download=true" + ... ) + >>> prompt = "一位年轻女子站在阳光明媚的海岸线上,画面为全身竖构图,身体微微侧向右侧,左手自然下垂,右臂弯曲扶在腰间,她的手指清晰可见,站姿放松而略带羞涩。她身穿轻盈的白色连衣裙,裙摆在海风中轻轻飘动,布料半透、质感柔软。女子拥有一头鲜艳的及腰紫色长发,被海风吹起,在身侧轻盈飞舞,发间系着一个精致的黑色蝴蝶结,与发色形成对比。她面容清秀,眉目精致,肤色白皙细腻,表情温柔略显羞涩,微微低头,眼神静静望向远处的海平线,流露出甜美的青春气息与若有所思的神情。背景是辽阔无垠的海洋与蔚蓝天空,阳光从侧前方洒下,海面波光粼粼,泛着温暖的金色光晕,天空清澈明亮,云朵稀薄,整体色调清新唯美。" + >>> image = pipe( + ... prompt, + ... image=image, + ... mask_image=mask_image, + ... control_image=control_image, + ... controlnet_conditioning_scale=0.75, + ... height=1728, + ... width=992, + ... num_inference_steps=25, + ... guidance_scale=0.0, + ... generator=torch.Generator("cuda").manual_seed(43), + ... ).images[0] + >>> image.save("zimage-inpaint.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImageControlNetInpaintPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin): + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageTransformer2DModel, + controlnet: ZImageControlNetModel, + ): + super().__init__() + if transformer.in_channels == controlnet.config.control_in_dim: + raise ValueError( + "ZImageControlNetInpaintPipeline is not compatible with `alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union`, use `alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0`." + ) + controlnet = ZImageControlNetModel.from_transformer(controlnet, transformer) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + controlnet=controlnet, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + prompt_embeds: list[torch.FloatTensor] | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 512, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + def _encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + prompt_embeds: list[torch.FloatTensor] | None = None, + max_sequence_length: int = 512, + ) -> list[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + + for i in range(len(prompt_embeds)): + embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embeddings_list + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 0 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 5.0, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + controlnet_conditioning_scale: float | list[float] = 0.75, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: list[torch.FloatTensor] | None = None, + negative_prompt_embeds: list[torch.FloatTensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + cfg_normalization (`bool`, *optional*, defaults to False): + Whether to apply configuration normalization. + cfg_truncation (`float`, *optional*, defaults to 1.0): + The truncation value for configuration. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`list[torch.FloatTensor]`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`list[torch.FloatTensor]`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain + tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + height = height or 1024 + width = width or 1024 + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + device = self._execution_device + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.in_channels + + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image.shape[-2:] + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator, sample_mode="argmax") + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + control_image = control_image.unsqueeze(2) + + mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width) + mask_condition = torch.tile(mask_condition, [1, 3, 1, 1]).to( + device=control_image.device, dtype=control_image.dtype + ) + + init_image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = init_image.shape[-2:] + init_image = init_image * (mask_condition < 0.5) + init_image = retrieve_latents(self.vae.encode(init_image), generator=generator, sample_mode="argmax") + init_image = (init_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + init_image = init_image.unsqueeze(2) + + mask_condition = F.interpolate(1 - mask_condition[:, :1], size=init_image.size()[-2:], mode="nearest").to( + device=control_image.device, dtype=control_image.dtype + ) + mask_condition = mask_condition.unsqueeze(2) + + control_image = torch.cat([control_image, mask_condition, init_image], dim=1) + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + actual_batch_size = batch_size * num_images_per_prompt + image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2) + + # 5. Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + timestep_model_input = timestep.repeat(2) + control_image_input = control_image.repeat(2, 1, 1, 1, 1) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + timestep_model_input = timestep + control_image_input = control_image + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + controlnet_block_samples = self.controlnet( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + control_image_input, + conditioning_scale=controlnet_conditioning_scale, + ) + + model_out_list = self.transformer( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + controlnet_block_samples=controlnet_block_samples, + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/pipeline_z_image_img2img.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/pipeline_z_image_img2img.py new file mode 100644 index 0000000000000000000000000000000000000000..ee57f51dd9576b4462e6c5c4d352f91d2045b46d --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/pipeline_z_image_img2img.py @@ -0,0 +1,709 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import torch +from transformers import AutoTokenizer, PreTrainedModel + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import ZImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import ZImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ZImageImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> pipe = ZImageImg2ImgPipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> init_image = load_image(url).resize((1024, 1024)) + + >>> prompt = "A fantasy landscape with mountains and a river, detailed, vibrant colors" + >>> image = pipe( + ... prompt, + ... image=init_image, + ... strength=0.6, + ... num_inference_steps=9, + ... guidance_scale=0.0, + ... generator=torch.Generator("cuda").manual_seed(42), + ... ).images[0] + >>> image.save("zimage_img2img.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImageImg2ImgPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin): + r""" + The ZImage pipeline for image-to-image generation. + + Args: + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`PreTrainedModel`]): + A text encoder model to encode text prompts. + tokenizer ([`AutoTokenizer`]): + A tokenizer to tokenize text prompts. + transformer ([`ZImageTransformer2DModel`]): + A ZImage transformer model to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + # Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + prompt_embeds: list[torch.FloatTensor] | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 512, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline._encode_prompt + def _encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + prompt_embeds: list[torch.FloatTensor] | None = None, + max_sequence_length: int = 512, + ) -> list[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + + for i in range(len(prompt_embeds)): + embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embeddings_list + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + return latents.to(device=device, dtype=dtype) + + # Encode the input image + image = image.to(device=device, dtype=dtype) + if image.shape[1] != num_channels_latents: + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + # Apply scaling (inverse of decoding: decode does latents/scaling_factor + shift_factor) + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + else: + image_latents = image + + # Handle batch size expansion + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + + # Add noise using flow matching scale_noise + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 0 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + image: PipelineImageInput = None, + strength: float = 0.6, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 5.0, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: list[torch.FloatTensor] | None = None, + negative_prompt_embeds: list[torch.FloatTensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for image-to-image generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a + list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or + a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. + strength (`float`, *optional*, defaults to 0.6): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. If not provided, uses the input image height. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. If not provided, uses the input image width. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + cfg_normalization (`bool`, *optional*, defaults to False): + Whether to apply configuration normalization. + cfg_truncation (`float`, *optional*, defaults to 1.0): + The truncation value for configuration. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`list[torch.FloatTensor]`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`list[torch.FloatTensor]`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain + tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + # 1. Check inputs and validate strength + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}") + + # 2. Preprocess image + init_image = self.image_processor.preprocess(image) + init_image = init_image.to(dtype=torch.float32) + + # Get dimensions from the preprocessed image if not specified + if height is None: + height = init_image.shape[-2] + if width is None: + width = init_image.shape[-1] + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + device = self._execution_device + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.in_channels + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + actual_batch_size = batch_size * num_images_per_prompt + + # Calculate latent dimensions for image_seq_len + latent_height = 2 * (int(height) // (self.vae_scale_factor * 2)) + latent_width = 2 * (int(width) // (self.vae_scale_factor * 2)) + image_seq_len = (latent_height // 2) * (latent_width // 2) + + # 5. Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + + # 6. Adjust timesteps based on strength + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline " + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(actual_batch_size) + + # 7. Prepare latents from image + latents = self.prepare_latents( + init_image, + latent_timestep, + actual_batch_size, + num_channels_latents, + height, + width, + prompt_embeds[0].dtype, + device, + generator, + latents, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 8. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + timestep_model_input = timestep.repeat(2) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + timestep_model_input = timestep + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + model_out_list = self.transformer( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/pipeline_z_image_inpaint.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/pipeline_z_image_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..e740a48e65ec4bf54623e8fec28a6f8c9439c792 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/pipeline_z_image_inpaint.py @@ -0,0 +1,932 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import AutoTokenizer, PreTrainedModel + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import ZImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import ZImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ZImageInpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = ZImageInpaintPipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> init_image = load_image(url).resize((1024, 1024)) + + >>> # Create a mask (white = inpaint, black = preserve) + >>> import numpy as np + >>> from PIL import Image + + >>> mask = np.zeros((1024, 1024), dtype=np.uint8) + >>> mask[256:768, 256:768] = 255 # Inpaint center region + >>> mask_image = Image.fromarray(mask) + + >>> prompt = "A beautiful lake with mountains in the background" + >>> image = pipe( + ... prompt, + ... image=init_image, + ... mask_image=mask_image, + ... strength=1.0, + ... num_inference_steps=9, + ... guidance_scale=0.0, + ... generator=torch.Generator("cuda").manual_seed(42), + ... ).images[0] + >>> image.save("zimage_inpaint.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImageInpaintPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin): + r""" + The ZImage pipeline for inpainting. + + Args: + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`PreTrainedModel`]): + A text encoder model to encode text prompts. + tokenizer ([`AutoTokenizer`]): + A tokenizer to tokenize text prompts. + transformer ([`ZImageTransformer2DModel`]): + A ZImage transformer model to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "mask", "masked_image_latents"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + + # Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + prompt_embeds: list[torch.FloatTensor] | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 512, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline._encode_prompt + def _encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + prompt_embeds: list[torch.FloatTensor] | None = None, + max_sequence_length: int = 512, + ) -> list[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + + for i in range(len(prompt_embeds)): + embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embeddings_list + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + height, + width, + dtype, + device, + generator, + ): + """Prepare mask and masked image latents for inpainting. + + Args: + mask: Binary mask tensor where 1 = inpaint region, 0 = preserve region. + masked_image: Original image with masked regions zeroed out. + batch_size: Number of images to generate. + height: Output image height. + width: Output image width. + dtype: Data type for the tensors. + device: Device to place tensors on. + generator: Random generator for reproducibility. + + Returns: + Tuple of (mask, masked_image_latents) prepared for the denoising loop. + """ + # Calculate latent dimensions + latent_height = 2 * (int(height) // (self.vae_scale_factor * 2)) + latent_width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + # Resize mask to latent dimensions + mask = torch.nn.functional.interpolate(mask, size=(latent_height, latent_width), mode="nearest") + mask = mask.to(device=device, dtype=dtype) + + # Encode masked image to latents + masked_image = masked_image.to(device=device, dtype=dtype) + if isinstance(generator, list): + masked_image_latents = [ + retrieve_latents(self.vae.encode(masked_image[i : i + 1]), generator=generator[i]) + for i in range(masked_image.shape[0]) + ] + masked_image_latents = torch.cat(masked_image_latents, dim=0) + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + # Apply VAE scaling + masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # Expand for batch size + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + return mask, masked_image_latents + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + """Prepare latents for inpainting, returning noise and image_latents for blending. + + Returns: + Tuple of (latents, noise, image_latents) where: + - latents: Noised image latents for denoising + - noise: The noise tensor used for blending + - image_latents: Clean image latents for blending + """ + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + # Generate noise for blending even if latents are provided + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # Encode image for blending + image = image.to(device=device, dtype=dtype) + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + image_latents = torch.cat([image_latents] * (batch_size // image_latents.shape[0]), dim=0) + return latents.to(device=device, dtype=dtype), noise, image_latents + + # Encode the input image + image = image.to(device=device, dtype=dtype) + if image.shape[1] != num_channels_latents: + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + # Apply scaling (inverse of decoding: decode does latents/scaling_factor + shift_factor) + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + else: + image_latents = image + + # Handle batch size expansion + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + + # Generate noise for both initial noising and later blending + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # Add noise using flow matching scale_noise + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + + return latents, noise, image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 0 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + def check_inputs( + self, + prompt, + image, + mask_image, + strength, + height, + width, + output_type, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if image is None: + raise ValueError("`image` input cannot be undefined for inpainting.") + + if mask_image is None: + raise ValueError("`mask_image` input cannot be undefined for inpainting.") + + if output_type not in ["latent", "pil", "np", "pt"]: + raise ValueError(f"`output_type` must be one of 'latent', 'pil', 'np', or 'pt', but got {output_type}") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: Optional[torch.FloatTensor] = None, + strength: float = 1.0, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 5.0, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: str = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for inpainting. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a + list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or + a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing a mask image for inpainting. White pixels (value 1) in the + mask will be inpainted, black pixels (value 0) will be preserved from the original image. + masked_image_latents (`torch.FloatTensor`, *optional*): + Pre-encoded masked image latents. If provided, the masked image encoding step will be skipped. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image` in the masked region. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. If not provided, uses the input image height. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. If not provided, uses the input image width. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + cfg_normalization (`bool`, *optional*, defaults to False): + Whether to apply configuration normalization. + cfg_truncation (`float`, *optional*, defaults to 1.0): + The truncation value for configuration. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain + tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + # 1. Check inputs + self.check_inputs( + prompt=prompt, + image=image, + mask_image=mask_image, + strength=strength, + height=height, + width=width, + output_type=output_type, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + # 2. Preprocess image and mask + init_image = self.image_processor.preprocess(image) + init_image = init_image.to(dtype=torch.float32) + + # Get dimensions from the preprocessed image if not specified + if height is None: + height = init_image.shape[-2] + if width is None: + width = init_image.shape[-1] + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + # Preprocess mask + mask = self.mask_processor.preprocess(mask_image, height=height, width=width) + + device = self._execution_device + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.in_channels + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + actual_batch_size = batch_size * num_images_per_prompt + + # Calculate latent dimensions for image_seq_len + latent_height = 2 * (int(height) // (self.vae_scale_factor * 2)) + latent_width = 2 * (int(width) // (self.vae_scale_factor * 2)) + image_seq_len = (latent_height // 2) * (latent_width // 2) + + # 5. Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + + # 6. Adjust timesteps based on strength + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline " + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(actual_batch_size) + + # 7. Prepare latents from image (returns noise and image_latents for blending) + latents, noise, image_latents = self.prepare_latents( + init_image, + latent_timestep, + actual_batch_size, + num_channels_latents, + height, + width, + prompt_embeds[0].dtype, + device, + generator, + latents, + ) + + # 8. Prepare mask and masked image latents + # Create masked image: preserve only unmasked regions (mask=0) + if masked_image_latents is None: + masked_image = init_image * (mask < 0.5) + else: + masked_image = None # Will use provided masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image if masked_image is not None else init_image, + actual_batch_size, + height, + width, + prompt_embeds[0].dtype, + device, + generator, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 9. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + timestep_model_input = timestep.repeat(2) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + timestep_model_input = timestep + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + model_out_list = self.transformer( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + # Inpainting blend: combine denoised latents with original image latents + init_latents_proper = image_latents + + # Re-scale original latents to current noise level for proper blending + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + # Blend: mask=1 for inpaint region (use denoised), mask=0 for preserve region (use original) + latents = (1 - mask) * init_latents_proper + mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + mask = callback_outputs.pop("mask", mask) + masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image) diff --git a/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/pipeline_z_image_omni.py b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/pipeline_z_image_omni.py new file mode 100644 index 0000000000000000000000000000000000000000..6d04202162f93e667ab7cb4f48098e651e2551a8 --- /dev/null +++ b/URSA/.venv_ursa/lib/python3.12/site-packages/diffusers/pipelines/z_image/pipeline_z_image_omni.py @@ -0,0 +1,742 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable + +import PIL +import torch +from transformers import AutoTokenizer, PreTrainedModel, Siglip2ImageProcessorFast, Siglip2VisionModel + +from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import ZImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..flux2.image_processor import Flux2ImageProcessor +from .pipeline_output import ZImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ZImageOmniPipeline + + >>> pipe = ZImageOmniPipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch. + >>> # (1) Use flash attention 2 + >>> # pipe.transformer.set_attention_backend("flash") + >>> # (2) Use flash attention 3 + >>> # pipe.transformer.set_attention_backend("_flash_3") + + >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。" + >>> image = pipe( + ... prompt, + ... height=1024, + ... width=1024, + ... num_inference_steps=9, + ... guidance_scale=0.0, + ... generator=torch.Generator("cuda").manual_seed(42), + ... ).images[0] + >>> image.save("zimage.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImageOmniPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin): + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageTransformer2DModel, + siglip: Siglip2VisionModel, + siglip_processor: Siglip2ImageProcessorFast, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + siglip=siglip, + siglip_processor=siglip_processor, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + # self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + prompt_embeds: list[torch.FloatTensor] | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + max_sequence_length: int = 512, + num_condition_images: int = 0, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + num_condition_images=num_condition_images, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + num_condition_images=num_condition_images, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + def _encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + prompt_embeds: list[torch.FloatTensor] | None = None, + max_sequence_length: int = 512, + num_condition_images: int = 0, + ) -> list[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + if num_condition_images == 0: + prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"] + elif num_condition_images > 0: + prompt_list = ["<|im_start|>user\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1) + prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|im_end|>"] + prompt[i] = prompt_list + + flattened_prompt = [] + prompt_list_lengths = [] + + for i in range(len(prompt)): + prompt_list_lengths.append(len(prompt[i])) + flattened_prompt.extend(prompt[i]) + + text_inputs = self.tokenizer( + flattened_prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + start_idx = 0 + for i in range(len(prompt_list_lengths)): + batch_embeddings = [] + end_idx = start_idx + prompt_list_lengths[i] + for j in range(start_idx, end_idx): + batch_embeddings.append(prompt_embeds[j][prompt_masks[j]]) + embeddings_list.append(batch_embeddings) + start_idx = end_idx + + return embeddings_list + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + def prepare_image_latents( + self, + images: list[torch.Tensor], + batch_size, + device, + dtype, + ): + image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + image_latent = ( + self.vae.encode(image.bfloat16()).latent_dist.mode()[0] - self.vae.config.shift_factor + ) * self.vae.config.scaling_factor + image_latent = image_latent.unsqueeze(1).to(dtype) + image_latents.append(image_latent) # (16, 128, 128) + + # image_latents = [image_latents] * batch_size + image_latents = [image_latents.copy() for _ in range(batch_size)] + + return image_latents + + def prepare_siglip_embeds( + self, + images: list[torch.Tensor], + batch_size, + device, + dtype, + ): + siglip_embeds = [] + for image in images: + siglip_inputs = self.siglip_processor(images=[image], return_tensors="pt").to(device) + shape = siglip_inputs.spatial_shapes[0] + hidden_state = self.siglip(**siglip_inputs).last_hidden_state + B, N, C = hidden_state.shape + hidden_state = hidden_state[:, : shape[0] * shape[1]] + hidden_state = hidden_state.view(shape[0], shape[1], C) + siglip_embeds.append(hidden_state.to(dtype)) + + # siglip_embeds = [siglip_embeds] * batch_size + siglip_embeds = [siglip_embeds.copy() for _ in range(batch_size)] + + return siglip_embeds + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 0 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: list[PIL.Image.Image, PIL.Image.Image] | None = None, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 50, + sigmas: list[float] | None = None, + guidance_scale: float = 5.0, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.FloatTensor | None = None, + prompt_embeds: list[torch.FloatTensor] | None = None, + negative_prompt_embeds: list[torch.FloatTensor] | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + joint_attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`list[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + cfg_normalization (`bool`, *optional*, defaults to False): + Whether to apply configuration normalization. + cfg_truncation (`float`, *optional*, defaults to 1.0): + The truncation value for configuration. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`list[torch.FloatTensor]`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`list[torch.FloatTensor]`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain + tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + if image is not None and not isinstance(image, list): + image = [image] + num_condition_images = len(image) if image is not None else 0 + + device = self._execution_device + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + num_condition_images=num_condition_images, + ) + + # 3. Process condition images. Copied from diffusers.pipelines.flux2.pipeline_flux2 + condition_images = [] + resized_images = [] + if image is not None: + for img in image: + self.image_processor.check_image_input(img) + for img in image: + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + if height is not None and width is not None: + img = self.image_processor._resize_to_target_area(img, height * width) + else: + img = self.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + resized_images.append(img) + + multiple_of = self.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") + condition_images.append(img) + + if len(condition_images) > 0: + height = height or image_height + width = width or image_width + + else: + height = height or 1024 + width = width or 1024 + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.in_channels + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + condition_latents = self.prepare_image_latents( + images=condition_images, + batch_size=batch_size * num_images_per_prompt, + device=device, + dtype=torch.float32, + ) + condition_latents = [[lat.to(self.transformer.dtype) for lat in lats] for lats in condition_latents] + if self.do_classifier_free_guidance: + negative_condition_latents = [[lat.clone() for lat in batch] for batch in condition_latents] + + condition_siglip_embeds = self.prepare_siglip_embeds( + images=resized_images, + batch_size=batch_size * num_images_per_prompt, + device=device, + dtype=torch.float32, + ) + condition_siglip_embeds = [[se.to(self.transformer.dtype) for se in sels] for sels in condition_siglip_embeds] + if self.do_classifier_free_guidance: + negative_condition_siglip_embeds = [[se.clone() for se in batch] for batch in condition_siglip_embeds] + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + condition_siglip_embeds = [None if sels == [] else sels + [None] for sels in condition_siglip_embeds] + negative_condition_siglip_embeds = [ + None if sels == [] else sels + [None] for sels in negative_condition_siglip_embeds + ] + + actual_batch_size = batch_size * num_images_per_prompt + image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2) + + # 5. Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + condition_latents_model_input = condition_latents + negative_condition_latents + condition_siglip_embeds_model_input = condition_siglip_embeds + negative_condition_siglip_embeds + timestep_model_input = timestep.repeat(2) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + condition_latents_model_input = condition_latents + condition_siglip_embeds_model_input = condition_siglip_embeds + timestep_model_input = timestep + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + # Combine condition latents with target latent + current_batch_size = len(latent_model_input_list) + x_combined = [ + condition_latents_model_input[i] + [latent_model_input_list[i]] for i in range(current_batch_size) + ] + # Create noise mask: 0 for condition images (clean), 1 for target image (noisy) + image_noise_mask = [ + [0] * len(condition_latents_model_input[i]) + [1] for i in range(current_batch_size) + ] + + model_out_list = self.transformer( + x=x_combined, + t=timestep_model_input, + cap_feats=prompt_embeds_model_input, + siglip_feats=condition_siglip_embeds_model_input, + image_noise_mask=image_noise_mask, + return_dict=False, + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image)